Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
46dda583
Commit
46dda583
authored
Jan 04, 2026
by
shaw
Browse files
Merge PR #146: feat: 提升流式网关稳定性与安全策略强化
parents
27ed042c
f8e7255c
Changes
62
Show whitespace changes
Inline
Side-by-side
backend/internal/repository/pricing_service_test.go
View file @
46dda583
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"net/http/httptest"
"testing"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/stretchr/testify/suite"
)
)
...
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
...
@@ -19,7 +20,13 @@ type PricingServiceSuite struct {
func
(
s
*
PricingServiceSuite
)
SetupTest
()
{
func
(
s
*
PricingServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
ctx
=
context
.
Background
()
client
,
ok
:=
NewPricingRemoteClient
()
.
(
*
pricingRemoteClient
)
client
,
ok
:=
NewPricingRemoteClient
(
&
config
.
Config
{
Security
:
config
.
SecurityConfig
{
URLAllowlist
:
config
.
URLAllowlistConfig
{
AllowPrivateHosts
:
true
,
},
},
})
.
(
*
pricingRemoteClient
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
require
.
True
(
s
.
T
(),
ok
,
"type assertion failed"
)
s
.
client
=
client
s
.
client
=
client
}
}
...
...
backend/internal/repository/proxy_probe_service.go
View file @
46dda583
...
@@ -5,28 +5,48 @@ import (
...
@@ -5,28 +5,48 @@ import (
"encoding/json"
"encoding/json"
"fmt"
"fmt"
"io"
"io"
"log"
"net/http"
"net/http"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
)
)
func
NewProxyExitInfoProber
()
service
.
ProxyExitInfoProber
{
func
NewProxyExitInfoProber
(
cfg
*
config
.
Config
)
service
.
ProxyExitInfoProber
{
return
&
proxyProbeService
{
ipInfoURL
:
defaultIPInfoURL
}
insecure
:=
false
allowPrivate
:=
false
if
cfg
!=
nil
{
insecure
=
cfg
.
Security
.
ProxyProbe
.
InsecureSkipVerify
allowPrivate
=
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
}
if
insecure
{
log
.
Printf
(
"[ProxyProbe] Warning: TLS verification is disabled for proxy probing."
)
}
return
&
proxyProbeService
{
ipInfoURL
:
defaultIPInfoURL
,
insecureSkipVerify
:
insecure
,
allowPrivateHosts
:
allowPrivate
,
}
}
}
const
defaultIPInfoURL
=
"https://ipinfo.io/json"
const
defaultIPInfoURL
=
"https://ipinfo.io/json"
type
proxyProbeService
struct
{
type
proxyProbeService
struct
{
ipInfoURL
string
ipInfoURL
string
insecureSkipVerify
bool
allowPrivateHosts
bool
}
}
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
func
(
s
*
proxyProbeService
)
ProbeProxy
(
ctx
context
.
Context
,
proxyURL
string
)
(
*
service
.
ProxyExitInfo
,
int64
,
error
)
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
proxyURL
,
ProxyURL
:
proxyURL
,
Timeout
:
15
*
time
.
Second
,
Timeout
:
15
*
time
.
Second
,
InsecureSkipVerify
:
true
,
InsecureSkipVerify
:
s
.
insecureSkipVerify
,
ProxyStrict
:
true
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
s
.
allowPrivateHosts
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"failed to create proxy client: %w"
,
err
)
return
nil
,
0
,
fmt
.
Errorf
(
"failed to create proxy client: %w"
,
err
)
...
...
backend/internal/repository/proxy_probe_service_test.go
View file @
46dda583
...
@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
...
@@ -20,7 +20,10 @@ type ProxyProbeServiceSuite struct {
func
(
s
*
ProxyProbeServiceSuite
)
SetupTest
()
{
func
(
s
*
ProxyProbeServiceSuite
)
SetupTest
()
{
s
.
ctx
=
context
.
Background
()
s
.
ctx
=
context
.
Background
()
s
.
prober
=
&
proxyProbeService
{
ipInfoURL
:
"http://ipinfo.test/json"
}
s
.
prober
=
&
proxyProbeService
{
ipInfoURL
:
"http://ipinfo.test/json"
,
allowPrivateHosts
:
true
,
}
}
}
func
(
s
*
ProxyProbeServiceSuite
)
TearDownTest
()
{
func
(
s
*
ProxyProbeServiceSuite
)
TearDownTest
()
{
...
...
backend/internal/repository/turnstile_service.go
View file @
46dda583
...
@@ -23,6 +23,7 @@ type turnstileVerifier struct {
...
@@ -23,6 +23,7 @@ type turnstileVerifier struct {
func
NewTurnstileVerifier
()
service
.
TurnstileVerifier
{
func
NewTurnstileVerifier
()
service
.
TurnstileVerifier
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
sharedClient
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
10
*
time
.
Second
,
Timeout
:
10
*
time
.
Second
,
ValidateResolvedIP
:
true
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
sharedClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Second
}
sharedClient
=
&
http
.
Client
{
Timeout
:
10
*
time
.
Second
}
...
...
backend/internal/server/api_contract_test.go
View file @
46dda583
...
@@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) {
...
@@ -294,13 +294,13 @@ func TestAPIContracts(t *testing.T) {
"smtp_host": "smtp.example.com",
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_port": 587,
"smtp_username": "user",
"smtp_username": "user",
"smtp_password
": "secret"
,
"smtp_password
_configured": true
,
"smtp_from_email": "no-reply@example.com",
"smtp_from_email": "no-reply@example.com",
"smtp_from_name": "Sub2API",
"smtp_from_name": "Sub2API",
"smtp_use_tls": true,
"smtp_use_tls": true,
"turnstile_enabled": true,
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_site_key": "site-key",
"turnstile_secret_key
": "secret-key"
,
"turnstile_secret_key
_configured": true
,
"site_name": "Sub2API",
"site_name": "Sub2API",
"site_logo": "",
"site_logo": "",
"site_subtitle": "Subtitle",
"site_subtitle": "Subtitle",
...
...
backend/internal/server/http.go
View file @
46dda583
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
package
server
package
server
import
(
import
(
"log"
"net/http"
"net/http"
"time"
"time"
...
@@ -36,6 +37,15 @@ func ProvideRouter(
...
@@ -36,6 +37,15 @@ func ProvideRouter(
r
:=
gin
.
New
()
r
:=
gin
.
New
()
r
.
Use
(
middleware2
.
Recovery
())
r
.
Use
(
middleware2
.
Recovery
())
if
len
(
cfg
.
Server
.
TrustedProxies
)
>
0
{
if
err
:=
r
.
SetTrustedProxies
(
cfg
.
Server
.
TrustedProxies
);
err
!=
nil
{
log
.
Printf
(
"Failed to set trusted proxies: %v"
,
err
)
}
}
else
{
if
err
:=
r
.
SetTrustedProxies
(
nil
);
err
!=
nil
{
log
.
Printf
(
"Failed to disable trusted proxies: %v"
,
err
)
}
}
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
return
SetupRouter
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
}
}
...
...
backend/internal/server/middleware/api_key_auth.go
View file @
46dda583
...
@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
...
@@ -19,6 +19,13 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func
apiKeyAuthWithSubscription
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
func
apiKeyAuthWithSubscription
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
return
func
(
c
*
gin
.
Context
)
{
queryKey
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
))
queryApiKey
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key"
))
if
queryKey
!=
""
||
queryApiKey
!=
""
{
AbortWithError
(
c
,
400
,
"api_key_in_query_deprecated"
,
"API key in query parameter is deprecated. Please use Authorization header instead."
)
return
}
// 尝试从Authorization header中提取API key (Bearer scheme)
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
authHeader
:=
c
.
GetHeader
(
"Authorization"
)
var
apiKeyString
string
var
apiKeyString
string
...
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
...
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKeyString
=
c
.
GetHeader
(
"x-goog-api-key"
)
apiKeyString
=
c
.
GetHeader
(
"x-goog-api-key"
)
}
}
// 如果header中没有,尝试从query参数中提取(Google API key风格)
if
apiKeyString
==
""
{
apiKeyString
=
c
.
Query
(
"key"
)
}
// 兼容常见别名
if
apiKeyString
==
""
{
apiKeyString
=
c
.
Query
(
"api_key"
)
}
// 如果所有header都没有API key
// 如果所有header都没有API key
if
apiKeyString
==
""
{
if
apiKeyString
==
""
{
AbortWithError
(
c
,
401
,
"API_KEY_REQUIRED"
,
"API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header
, or key/api_key query parameter
"
)
AbortWithError
(
c
,
401
,
"API_KEY_REQUIRED"
,
"API key is required in Authorization header (Bearer scheme), x-api-key header,
or
x-goog-api-key header"
)
return
return
}
}
...
...
backend/internal/server/middleware/api_key_auth_google.go
View file @
46dda583
...
@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
...
@@ -22,6 +22,10 @@ func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
func
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
cfg
*
config
.
Config
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
return
func
(
c
*
gin
.
Context
)
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key"
));
v
!=
""
{
abortWithGoogleError
(
c
,
400
,
"Query parameter api_key is deprecated. Use Authorization header or key instead."
)
return
}
apiKeyString
:=
extractAPIKeyFromRequest
(
c
)
apiKeyString
:=
extractAPIKeyFromRequest
(
c
)
if
apiKeyString
==
""
{
if
apiKeyString
==
""
{
abortWithGoogleError
(
c
,
401
,
"API key is required"
)
abortWithGoogleError
(
c
,
401
,
"API key is required"
)
...
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
...
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
v
!=
""
{
if
v
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"x-goog-api-key"
));
v
!=
""
{
return
v
return
v
}
}
if
allowGoogleQueryKey
(
c
.
Request
.
URL
.
Path
)
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
));
v
!=
""
{
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"key"
));
v
!=
""
{
return
v
return
v
}
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key"
));
v
!=
""
{
return
v
}
}
return
""
return
""
}
}
func
allowGoogleQueryKey
(
path
string
)
bool
{
return
strings
.
HasPrefix
(
path
,
"/v1beta"
)
||
strings
.
HasPrefix
(
path
,
"/antigravity/v1beta"
)
}
func
abortWithGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
{
func
abortWithGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
{
c
.
JSON
(
status
,
gin
.
H
{
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"error"
:
gin
.
H
{
...
...
backend/internal/server/middleware/api_key_auth_google_test.go
View file @
46dda583
...
@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
...
@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require
.
Equal
(
t
,
"UNAUTHENTICATED"
,
resp
.
Error
.
Status
)
require
.
Equal
(
t
,
"UNAUTHENTICATED"
,
resp
.
Error
.
Status
)
}
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
nil
,
errors
.
New
(
"should not be called"
)
},
})
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
&
config
.
Config
{}))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test?api_key=legacy"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
var
resp
googleErrorResponse
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
resp
.
Error
.
Code
)
require
.
Equal
(
t
,
"Query parameter api_key is deprecated. Use Authorization header or key instead."
,
resp
.
Error
.
Message
)
require
.
Equal
(
t
,
"INVALID_ARGUMENT"
,
resp
.
Error
.
Status
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
r
:=
gin
.
New
()
apiKeyService
:=
newTestAPIKeyService
(
fakeAPIKeyRepo
{
getByKey
:
func
(
ctx
context
.
Context
,
key
string
)
(
*
service
.
APIKey
,
error
)
{
return
&
service
.
APIKey
{
ID
:
1
,
Key
:
key
,
Status
:
service
.
StatusActive
,
User
:
&
service
.
User
{
ID
:
123
,
Status
:
service
.
StatusActive
,
},
},
nil
},
})
cfg
:=
&
config
.
Config
{
RunMode
:
config
.
RunModeSimple
}
r
.
Use
(
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
nil
,
cfg
))
r
.
GET
(
"/v1beta/test"
,
func
(
c
*
gin
.
Context
)
{
c
.
JSON
(
200
,
gin
.
H
{
"ok"
:
true
})
})
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/v1beta/test?key=valid"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
}
func
TestApiKeyAuthWithSubscriptionGoogle_InvalidKey
(
t
*
testing
.
T
)
{
func
TestApiKeyAuthWithSubscriptionGoogle_InvalidKey
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
gin
.
SetMode
(
gin
.
TestMode
)
...
...
backend/internal/server/middleware/cors.go
View file @
46dda583
package
middleware
package
middleware
import
(
import
(
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
)
)
var
corsWarningOnce
sync
.
Once
// CORS 跨域中间件
// CORS 跨域中间件
func
CORS
()
gin
.
HandlerFunc
{
func
CORS
(
cfg
config
.
CORSConfig
)
gin
.
HandlerFunc
{
allowedOrigins
:=
normalizeOrigins
(
cfg
.
AllowedOrigins
)
allowAll
:=
false
for
_
,
origin
:=
range
allowedOrigins
{
if
origin
==
"*"
{
allowAll
=
true
break
}
}
wildcardWithSpecific
:=
allowAll
&&
len
(
allowedOrigins
)
>
1
if
wildcardWithSpecific
{
allowedOrigins
=
[]
string
{
"*"
}
}
allowCredentials
:=
cfg
.
AllowCredentials
corsWarningOnce
.
Do
(
func
()
{
if
len
(
allowedOrigins
)
==
0
{
log
.
Println
(
"Warning: CORS allowed_origins not configured; cross-origin requests will be rejected."
)
}
if
wildcardWithSpecific
{
log
.
Println
(
"Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins."
)
}
if
allowAll
&&
allowCredentials
{
log
.
Println
(
"Warning: CORS allowed_origins set to '*', disabling allow_credentials."
)
}
})
if
allowAll
&&
allowCredentials
{
allowCredentials
=
false
}
allowedSet
:=
make
(
map
[
string
]
struct
{},
len
(
allowedOrigins
))
for
_
,
origin
:=
range
allowedOrigins
{
if
origin
==
""
||
origin
==
"*"
{
continue
}
allowedSet
[
origin
]
=
struct
{}{}
}
return
func
(
c
*
gin
.
Context
)
{
return
func
(
c
*
gin
.
Context
)
{
// 设置允许跨域的响应头
origin
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"Origin"
))
originAllowed
:=
allowAll
if
origin
!=
""
&&
!
allowAll
{
_
,
originAllowed
=
allowedSet
[
origin
]
}
if
originAllowed
{
if
allowAll
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Origin"
,
"*"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Origin"
,
"*"
)
}
else
if
origin
!=
""
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Origin"
,
origin
)
c
.
Writer
.
Header
()
.
Add
(
"Vary"
,
"Origin"
)
}
if
allowCredentials
{
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Credentials"
,
"true"
)
}
}
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
"Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Headers"
,
"Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
c
.
Writer
.
Header
()
.
Set
(
"Access-Control-Allow-Methods"
,
"POST, OPTIONS, GET, PUT, DELETE, PATCH"
)
// 处理预检请求
// 处理预检请求
if
c
.
Request
.
Method
==
"OPTIONS"
{
if
c
.
Request
.
Method
==
http
.
MethodOptions
{
c
.
AbortWithStatus
(
204
)
if
originAllowed
{
c
.
AbortWithStatus
(
http
.
StatusNoContent
)
}
else
{
c
.
AbortWithStatus
(
http
.
StatusForbidden
)
}
return
return
}
}
c
.
Next
()
c
.
Next
()
}
}
}
}
func
normalizeOrigins
(
values
[]
string
)
[]
string
{
if
len
(
values
)
==
0
{
return
nil
}
normalized
:=
make
([]
string
,
0
,
len
(
values
))
for
_
,
value
:=
range
values
{
trimmed
:=
strings
.
TrimSpace
(
value
)
if
trimmed
==
""
{
continue
}
normalized
=
append
(
normalized
,
trimmed
)
}
return
normalized
}
backend/internal/server/middleware/security_headers.go
0 → 100644
View file @
46dda583
package
middleware
import
(
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func
SecurityHeaders
(
cfg
config
.
CSPConfig
)
gin
.
HandlerFunc
{
policy
:=
strings
.
TrimSpace
(
cfg
.
Policy
)
if
policy
==
""
{
policy
=
config
.
DefaultCSPPolicy
}
return
func
(
c
*
gin
.
Context
)
{
c
.
Header
(
"X-Content-Type-Options"
,
"nosniff"
)
c
.
Header
(
"X-Frame-Options"
,
"DENY"
)
c
.
Header
(
"Referrer-Policy"
,
"strict-origin-when-cross-origin"
)
if
cfg
.
Enabled
{
c
.
Header
(
"Content-Security-Policy"
,
policy
)
}
c
.
Next
()
}
}
backend/internal/server/router.go
View file @
46dda583
...
@@ -24,7 +24,8 @@ func SetupRouter(
...
@@ -24,7 +24,8 @@ func SetupRouter(
)
*
gin
.
Engine
{
)
*
gin
.
Engine
{
// 应用中间件
// 应用中间件
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
CORS
())
r
.
Use
(
middleware2
.
CORS
(
cfg
.
CORS
))
r
.
Use
(
middleware2
.
SecurityHeaders
(
cfg
.
Security
.
CSP
))
// Serve embedded frontend if available
// Serve embedded frontend if available
if
web
.
HasEmbeddedFrontend
()
{
if
web
.
HasEmbeddedFrontend
()
{
...
...
backend/internal/service/account_test_service.go
View file @
46dda583
...
@@ -7,6 +7,7 @@ import (
...
@@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/rand"
"encoding/hex"
"encoding/hex"
"encoding/json"
"encoding/json"
"errors"
"fmt"
"fmt"
"io"
"io"
"log"
"log"
...
@@ -14,9 +15,11 @@ import (
...
@@ -14,9 +15,11 @@ import (
"regexp"
"regexp"
"strings"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/google/uuid"
)
)
...
@@ -45,6 +48,7 @@ type AccountTestService struct {
...
@@ -45,6 +48,7 @@ type AccountTestService struct {
geminiTokenProvider
*
GeminiTokenProvider
geminiTokenProvider
*
GeminiTokenProvider
antigravityGatewayService
*
AntigravityGatewayService
antigravityGatewayService
*
AntigravityGatewayService
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
cfg
*
config
.
Config
}
}
// NewAccountTestService creates a new AccountTestService
// NewAccountTestService creates a new AccountTestService
...
@@ -53,15 +57,32 @@ func NewAccountTestService(
...
@@ -53,15 +57,32 @@ func NewAccountTestService(
geminiTokenProvider
*
GeminiTokenProvider
,
geminiTokenProvider
*
GeminiTokenProvider
,
antigravityGatewayService
*
AntigravityGatewayService
,
antigravityGatewayService
*
AntigravityGatewayService
,
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
cfg
*
config
.
Config
,
)
*
AccountTestService
{
)
*
AccountTestService
{
return
&
AccountTestService
{
return
&
AccountTestService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
geminiTokenProvider
:
geminiTokenProvider
,
geminiTokenProvider
:
geminiTokenProvider
,
antigravityGatewayService
:
antigravityGatewayService
,
antigravityGatewayService
:
antigravityGatewayService
,
httpUpstream
:
httpUpstream
,
httpUpstream
:
httpUpstream
,
cfg
:
cfg
,
}
}
}
}
func
(
s
*
AccountTestService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
if
s
.
cfg
==
nil
{
return
""
,
errors
.
New
(
"config is not available"
)
}
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
err
}
return
normalized
,
nil
}
// generateSessionString generates a Claude Code style session string
// generateSessionString generates a Claude Code style session string
func
generateSessionString
()
(
string
,
error
)
{
func
generateSessionString
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
bytes
:=
make
([]
byte
,
32
)
...
@@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
...
@@ -183,11 +204,15 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return
s
.
sendErrorAndEnd
(
c
,
"No API key available"
)
return
s
.
sendErrorAndEnd
(
c
,
"No API key available"
)
}
}
api
URL
=
account
.
GetBaseURL
()
base
URL
:
=
account
.
GetBaseURL
()
if
api
URL
==
""
{
if
base
URL
==
""
{
api
URL
=
"https://api.anthropic.com"
base
URL
=
"https://api.anthropic.com"
}
}
apiURL
=
strings
.
TrimSuffix
(
apiURL
,
"/"
)
+
"/v1/messages"
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Invalid base URL: %s"
,
err
.
Error
()))
}
apiURL
=
strings
.
TrimSuffix
(
normalizedBaseURL
,
"/"
)
+
"/v1/messages"
}
else
{
}
else
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
}
...
@@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
...
@@ -300,7 +325,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
"https://api.openai.com"
baseURL
=
"https://api.openai.com"
}
}
apiURL
=
strings
.
TrimSuffix
(
baseURL
,
"/"
)
+
"/responses"
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Invalid base URL: %s"
,
err
.
Error
()))
}
apiURL
=
strings
.
TrimSuffix
(
normalizedBaseURL
,
"/"
)
+
"/responses"
}
else
{
}
else
{
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
return
s
.
sendErrorAndEnd
(
c
,
fmt
.
Sprintf
(
"Unsupported account type: %s"
,
account
.
Type
))
}
}
...
@@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
...
@@ -480,10 +509,14 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
// Use streamGenerateContent for real-time feedback
// Use streamGenerateContent for real-time feedback
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
modelID
)
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
modelID
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
fullURL
,
bytes
.
NewReader
(
payload
))
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
fullURL
,
bytes
.
NewReader
(
payload
))
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
...
@@ -515,7 +548,11 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
if
strings
.
TrimSpace
(
baseURL
)
==
""
{
if
strings
.
TrimSpace
(
baseURL
)
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
modelID
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:streamGenerateContent?alt=sse"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
),
modelID
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
payload
))
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
payload
))
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
...
@@ -544,7 +581,11 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT
}
}
wrappedBytes
,
_
:=
json
.
Marshal
(
wrapped
)
wrappedBytes
,
_
:=
json
.
Marshal
(
wrapped
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:streamGenerateContent?alt=sse"
,
geminicli
.
GeminiCliBaseURL
)
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:streamGenerateContent?alt=sse"
,
normalizedBaseURL
)
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
fullURL
,
bytes
.
NewReader
(
wrappedBytes
))
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
"POST"
,
fullURL
,
bytes
.
NewReader
(
wrappedBytes
))
if
err
!=
nil
{
if
err
!=
nil
{
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
46dda583
...
@@ -11,6 +11,7 @@ import (
...
@@ -11,6 +11,7 @@ import (
"log"
"log"
"net/http"
"net/http"
"strings"
"strings"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
...
@@ -916,20 +917,102 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -916,20 +917,102 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return
nil
,
errors
.
New
(
"streaming not supported"
)
return
nil
,
errors
.
New
(
"streaming not supported"
)
}
}
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
usage
:=
&
ClaudeUsage
{}
usage
:=
&
ClaudeUsage
{}
var
firstTokenMs
*
int
var
firstTokenMs
*
int
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
for
{
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
select
{
if
len
(
line
)
>
0
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
nil
,
ev
.
err
}
line
:=
ev
.
line
trimmed
:=
strings
.
TrimRight
(
line
,
"
\r\n
"
)
trimmed
:=
strings
.
TrimRight
(
line
,
"
\r\n
"
)
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
if
strings
.
HasPrefix
(
trimmed
,
"data:"
)
{
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
payload
:=
strings
.
TrimSpace
(
strings
.
TrimPrefix
(
trimmed
,
"data:"
))
if
payload
==
""
||
payload
==
"[DONE]"
{
if
payload
==
""
||
payload
==
"[DONE]"
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
flusher
.
Flush
()
flusher
.
Flush
()
}
else
{
continue
}
// 解包 v1internal 响应
// 解包 v1internal 响应
inner
,
parseErr
:=
s
.
unwrapV1InternalResponse
([]
byte
(
payload
))
inner
,
parseErr
:=
s
.
unwrapV1InternalResponse
([]
byte
(
payload
))
if
parseErr
==
nil
&&
inner
!=
nil
{
if
parseErr
==
nil
&&
inner
!=
nil
{
...
@@ -949,24 +1032,30 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
...
@@ -949,24 +1032,30 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
firstTokenMs
=
&
ms
firstTokenMs
=
&
ms
}
}
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
)
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"data: %s
\n\n
"
,
payload
);
err
!=
nil
{
flusher
.
Flush
()
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
}
}
else
{
_
,
_
=
io
.
WriteString
(
c
.
Writer
,
line
)
flusher
.
Flush
()
flusher
.
Flush
()
continue
}
}
if
_
,
err
:=
fmt
.
Fprintf
(
c
.
Writer
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
}
flusher
.
Flush
()
if
errors
.
Is
(
err
,
io
.
EOF
)
{
case
<-
intervalCh
:
break
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
}
if
err
!=
nil
{
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
return
nil
,
err
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
}
}
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
}
func
(
s
*
AntigravityGatewayService
)
handleGeminiNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
)
(
*
ClaudeUsage
,
error
)
{
func
(
s
*
AntigravityGatewayService
)
handleGeminiNonStreamingResponse
(
c
*
gin
.
Context
,
resp
*
http
.
Response
)
(
*
ClaudeUsage
,
error
)
{
...
@@ -1105,7 +1194,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -1105,7 +1194,13 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
processor
:=
antigravity
.
NewStreamingProcessor
(
originalModel
)
processor
:=
antigravity
.
NewStreamingProcessor
(
originalModel
)
var
firstTokenMs
*
int
var
firstTokenMs
*
int
reader
:=
bufio
.
NewReader
(
resp
.
Body
)
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
settingService
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage
:=
func
(
agUsage
*
antigravity
.
ClaudeUsage
)
*
ClaudeUsage
{
convertUsage
:=
func
(
agUsage
*
antigravity
.
ClaudeUsage
)
*
ClaudeUsage
{
...
@@ -1120,13 +1215,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -1120,13 +1215,85 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
}
}
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞影响超时处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
settingService
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
c
.
Writer
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
for
{
for
{
line
,
err
:=
reader
.
ReadString
(
'\n'
)
select
{
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
io
.
EOF
)
{
case
ev
,
ok
:=
<-
events
:
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
if
!
ok
{
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long (antigravity): max_size=%d error=%v"
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
nil
,
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
}
if
len
(
line
)
>
0
{
line
:=
ev
.
line
// 处理 SSE 行,转换为 Claude 格式
// 处理 SSE 行,转换为 Claude 格式
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
claudeEvents
:=
processor
.
ProcessLine
(
strings
.
TrimRight
(
line
,
"
\r\n
"
))
...
@@ -1141,23 +1308,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
...
@@ -1141,23 +1308,21 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if
len
(
finalEvents
)
>
0
{
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
}
}
sendErrorEvent
(
"write_failed"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
writeErr
}
}
flusher
.
Flush
()
flusher
.
Flush
()
}
}
}
if
errors
.
Is
(
err
,
io
.
EOF
)
{
case
<-
intervalCh
:
break
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
// 发送结束事件
finalEvents
,
agUsage
:=
processor
.
Finish
()
if
len
(
finalEvents
)
>
0
{
_
,
_
=
c
.
Writer
.
Write
(
finalEvents
)
flusher
.
Flush
()
}
}
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
agUsage
),
firstTokenMs
:
firstTokenMs
},
nil
}
}
backend/internal/service/auth_service.go
View file @
46dda583
...
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
...
@@ -221,9 +221,33 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// VerifyTurnstile 验证Turnstile token
// VerifyTurnstile 验证Turnstile token
func
(
s
*
AuthService
)
VerifyTurnstile
(
ctx
context
.
Context
,
token
string
,
remoteIP
string
)
error
{
func
(
s
*
AuthService
)
VerifyTurnstile
(
ctx
context
.
Context
,
token
string
,
remoteIP
string
)
error
{
required
:=
s
.
cfg
!=
nil
&&
s
.
cfg
.
Server
.
Mode
==
"release"
&&
s
.
cfg
.
Turnstile
.
Required
if
required
{
if
s
.
settingService
==
nil
{
log
.
Println
(
"[Auth] Turnstile required but settings service is not configured"
)
return
ErrTurnstileNotConfigured
}
enabled
:=
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
secretConfigured
:=
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
!=
""
if
!
enabled
||
!
secretConfigured
{
log
.
Printf
(
"[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)"
,
enabled
,
secretConfigured
)
return
ErrTurnstileNotConfigured
}
}
if
s
.
turnstileService
==
nil
{
if
s
.
turnstileService
==
nil
{
if
required
{
log
.
Println
(
"[Auth] Turnstile required but service not configured"
)
return
ErrTurnstileNotConfigured
}
return
nil
// 服务未配置则跳过验证
return
nil
// 服务未配置则跳过验证
}
}
if
!
required
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsTurnstileEnabled
(
ctx
)
&&
s
.
settingService
.
GetTurnstileSecretKey
(
ctx
)
==
""
{
log
.
Println
(
"[Auth] Turnstile enabled but secret key not configured"
)
}
return
s
.
turnstileService
.
VerifyToken
(
ctx
,
token
,
remoteIP
)
return
s
.
turnstileService
.
VerifyToken
(
ctx
,
token
,
remoteIP
)
}
}
...
...
backend/internal/service/billing_cache_service.go
View file @
46dda583
...
@@ -17,6 +17,7 @@ import (
...
@@ -17,6 +17,7 @@ import (
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var
(
var
(
ErrSubscriptionInvalid
=
infraerrors
.
Forbidden
(
"SUBSCRIPTION_INVALID"
,
"subscription is invalid or expired"
)
ErrSubscriptionInvalid
=
infraerrors
.
Forbidden
(
"SUBSCRIPTION_INVALID"
,
"subscription is invalid or expired"
)
ErrBillingServiceUnavailable
=
infraerrors
.
ServiceUnavailable
(
"BILLING_SERVICE_ERROR"
,
"Billing service temporarily unavailable. Please retry later."
)
)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
...
@@ -76,6 +77,7 @@ type BillingCacheService struct {
...
@@ -76,6 +77,7 @@ type BillingCacheService struct {
userRepo
UserRepository
userRepo
UserRepository
subRepo
UserSubscriptionRepository
subRepo
UserSubscriptionRepository
cfg
*
config
.
Config
cfg
*
config
.
Config
circuitBreaker
*
billingCircuitBreaker
cacheWriteChan
chan
cacheWriteTask
cacheWriteChan
chan
cacheWriteTask
cacheWriteWg
sync
.
WaitGroup
cacheWriteWg
sync
.
WaitGroup
...
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
...
@@ -95,6 +97,7 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
subRepo
:
subRepo
,
subRepo
:
subRepo
,
cfg
:
cfg
,
cfg
:
cfg
,
}
}
svc
.
circuitBreaker
=
newBillingCircuitBreaker
(
cfg
.
Billing
.
CircuitBreaker
)
svc
.
startCacheWriteWorkers
()
svc
.
startCacheWriteWorkers
()
return
svc
return
svc
}
}
...
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
...
@@ -450,6 +453,9 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
if
s
.
cfg
.
RunMode
==
config
.
RunModeSimple
{
return
nil
return
nil
}
}
if
s
.
circuitBreaker
!=
nil
&&
!
s
.
circuitBreaker
.
Allow
()
{
return
ErrBillingServiceUnavailable
}
// 判断计费模式
// 判断计费模式
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
isSubscriptionMode
:=
group
!=
nil
&&
group
.
IsSubscriptionType
()
&&
subscription
!=
nil
...
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
...
@@ -465,9 +471,14 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
func
(
s
*
BillingCacheService
)
checkBalanceEligibility
(
ctx
context
.
Context
,
userID
int64
)
error
{
func
(
s
*
BillingCacheService
)
checkBalanceEligibility
(
ctx
context
.
Context
,
userID
int64
)
error
{
balance
,
err
:=
s
.
GetUserBalance
(
ctx
,
userID
)
balance
,
err
:=
s
.
GetUserBalance
(
ctx
,
userID
)
if
err
!=
nil
{
if
err
!=
nil
{
// 缓存/数据库错误,允许通过(降级处理)
if
s
.
circuitBreaker
!=
nil
{
log
.
Printf
(
"Warning: get user balance failed, allowing request: %v"
,
err
)
s
.
circuitBreaker
.
OnFailure
(
err
)
return
nil
}
log
.
Printf
(
"ALERT: billing balance check failed for user %d: %v"
,
userID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnSuccess
()
}
}
if
balance
<=
0
{
if
balance
<=
0
{
...
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
...
@@ -482,9 +493,14 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
// 获取订阅缓存数据
// 获取订阅缓存数据
subData
,
err
:=
s
.
GetSubscriptionStatus
(
ctx
,
userID
,
group
.
ID
)
subData
,
err
:=
s
.
GetSubscriptionStatus
(
ctx
,
userID
,
group
.
ID
)
if
err
!=
nil
{
if
err
!=
nil
{
// 缓存/数据库错误,降级使用传入的subscription进行检查
if
s
.
circuitBreaker
!=
nil
{
log
.
Printf
(
"Warning: get subscription cache failed, using fallback: %v"
,
err
)
s
.
circuitBreaker
.
OnFailure
(
err
)
return
s
.
checkSubscriptionLimitsFallback
(
subscription
,
group
)
}
log
.
Printf
(
"ALERT: billing subscription check failed for user %d group %d: %v"
,
userID
,
group
.
ID
,
err
)
return
ErrBillingServiceUnavailable
.
WithCause
(
err
)
}
if
s
.
circuitBreaker
!=
nil
{
s
.
circuitBreaker
.
OnSuccess
()
}
}
// 检查订阅状态
// 检查订阅状态
...
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
...
@@ -513,27 +529,133 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
return
nil
return
nil
}
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
type
billingCircuitBreakerState
int
func
(
s
*
BillingCacheService
)
checkSubscriptionLimitsFallback
(
subscription
*
UserSubscription
,
group
*
Group
)
error
{
if
subscription
==
nil
{
const
(
return
ErrSubscriptionInvalid
billingCircuitClosed
billingCircuitBreakerState
=
iota
billingCircuitOpen
billingCircuitHalfOpen
)
type
billingCircuitBreaker
struct
{
mu
sync
.
Mutex
state
billingCircuitBreakerState
failures
int
openedAt
time
.
Time
failureThreshold
int
resetTimeout
time
.
Duration
halfOpenRequests
int
halfOpenRemaining
int
}
func
newBillingCircuitBreaker
(
cfg
config
.
CircuitBreakerConfig
)
*
billingCircuitBreaker
{
if
!
cfg
.
Enabled
{
return
nil
}
resetTimeout
:=
time
.
Duration
(
cfg
.
ResetTimeoutSeconds
)
*
time
.
Second
if
resetTimeout
<=
0
{
resetTimeout
=
30
*
time
.
Second
}
halfOpen
:=
cfg
.
HalfOpenRequests
if
halfOpen
<=
0
{
halfOpen
=
1
}
}
threshold
:=
cfg
.
FailureThreshold
if
threshold
<=
0
{
threshold
=
5
}
return
&
billingCircuitBreaker
{
state
:
billingCircuitClosed
,
failureThreshold
:
threshold
,
resetTimeout
:
resetTimeout
,
halfOpenRequests
:
halfOpen
,
}
}
if
!
subscription
.
IsActive
()
{
func
(
b
*
billingCircuitBreaker
)
Allow
()
bool
{
return
ErrSubscriptionInvalid
b
.
mu
.
Lock
()
defer
b
.
mu
.
Unlock
()
switch
b
.
state
{
case
billingCircuitClosed
:
return
true
case
billingCircuitOpen
:
if
time
.
Since
(
b
.
openedAt
)
<
b
.
resetTimeout
{
return
false
}
b
.
state
=
billingCircuitHalfOpen
b
.
halfOpenRemaining
=
b
.
halfOpenRequests
log
.
Printf
(
"ALERT: billing circuit breaker entering half-open state"
)
fallthrough
case
billingCircuitHalfOpen
:
if
b
.
halfOpenRemaining
<=
0
{
return
false
}
}
b
.
halfOpenRemaining
--
return
true
default
:
return
false
}
}
if
!
subscription
.
CheckDailyLimit
(
group
,
0
)
{
func
(
b
*
billingCircuitBreaker
)
OnFailure
(
err
error
)
{
return
ErrDailyLimitExceeded
if
b
==
nil
{
return
}
}
b
.
mu
.
Lock
()
defer
b
.
mu
.
Unlock
()
if
!
subscription
.
CheckWeeklyLimit
(
group
,
0
)
{
switch
b
.
state
{
return
ErrWeeklyLimitExceeded
case
billingCircuitOpen
:
return
case
billingCircuitHalfOpen
:
b
.
state
=
billingCircuitOpen
b
.
openedAt
=
time
.
Now
()
b
.
halfOpenRemaining
=
0
log
.
Printf
(
"ALERT: billing circuit breaker opened after half-open failure: %v"
,
err
)
return
default
:
b
.
failures
++
if
b
.
failures
>=
b
.
failureThreshold
{
b
.
state
=
billingCircuitOpen
b
.
openedAt
=
time
.
Now
()
b
.
halfOpenRemaining
=
0
log
.
Printf
(
"ALERT: billing circuit breaker opened after %d failures: %v"
,
b
.
failures
,
err
)
}
}
}
}
if
!
subscription
.
CheckMonthlyLimit
(
group
,
0
)
{
func
(
b
*
billingCircuitBreaker
)
OnSuccess
()
{
return
ErrMonthlyLimitExceeded
if
b
==
nil
{
return
}
}
b
.
mu
.
Lock
()
defer
b
.
mu
.
Unlock
()
return
nil
previousState
:=
b
.
state
previousFailures
:=
b
.
failures
b
.
state
=
billingCircuitClosed
b
.
failures
=
0
b
.
halfOpenRemaining
=
0
// 只有状态真正发生变化时才记录日志
if
previousState
!=
billingCircuitClosed
{
log
.
Printf
(
"ALERT: billing circuit breaker closed (was %s)"
,
circuitStateString
(
previousState
))
}
else
if
previousFailures
>
0
{
log
.
Printf
(
"INFO: billing circuit breaker failures reset from %d"
,
previousFailures
)
}
}
func
circuitStateString
(
state
billingCircuitBreakerState
)
string
{
switch
state
{
case
billingCircuitClosed
:
return
"closed"
case
billingCircuitOpen
:
return
"open"
case
billingCircuitHalfOpen
:
return
"half-open"
default
:
return
"unknown"
}
}
}
backend/internal/service/crs_sync_service.go
View file @
46dda583
...
@@ -8,12 +8,13 @@ import (
...
@@ -8,12 +8,13 @@ import (
"fmt"
"fmt"
"io"
"io"
"net/http"
"net/http"
"net/url"
"strconv"
"strconv"
"strings"
"strings"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
)
type
CRSSyncService
struct
{
type
CRSSyncService
struct
{
...
@@ -22,6 +23,7 @@ type CRSSyncService struct {
...
@@ -22,6 +23,7 @@ type CRSSyncService struct {
oauthService
*
OAuthService
oauthService
*
OAuthService
openaiOAuthService
*
OpenAIOAuthService
openaiOAuthService
*
OpenAIOAuthService
geminiOAuthService
*
GeminiOAuthService
geminiOAuthService
*
GeminiOAuthService
cfg
*
config
.
Config
}
}
func
NewCRSSyncService
(
func
NewCRSSyncService
(
...
@@ -30,6 +32,7 @@ func NewCRSSyncService(
...
@@ -30,6 +32,7 @@ func NewCRSSyncService(
oauthService
*
OAuthService
,
oauthService
*
OAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
cfg
*
config
.
Config
,
)
*
CRSSyncService
{
)
*
CRSSyncService
{
return
&
CRSSyncService
{
return
&
CRSSyncService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -37,6 +40,7 @@ func NewCRSSyncService(
...
@@ -37,6 +40,7 @@ func NewCRSSyncService(
oauthService
:
oauthService
,
oauthService
:
oauthService
,
openaiOAuthService
:
openaiOAuthService
,
openaiOAuthService
:
openaiOAuthService
,
geminiOAuthService
:
geminiOAuthService
,
geminiOAuthService
:
geminiOAuthService
,
cfg
:
cfg
,
}
}
}
}
...
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
...
@@ -187,7 +191,10 @@ type crsGeminiAPIKeyAccount struct {
}
}
func
(
s
*
CRSSyncService
)
SyncFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
SyncFromCRSResult
,
error
)
{
func
(
s
*
CRSSyncService
)
SyncFromCRS
(
ctx
context
.
Context
,
input
SyncFromCRSInput
)
(
*
SyncFromCRSResult
,
error
)
{
baseURL
,
err
:=
normalizeBaseURL
(
input
.
BaseURL
)
if
s
.
cfg
==
nil
{
return
nil
,
errors
.
New
(
"config is not available"
)
}
baseURL
,
err
:=
normalizeBaseURL
(
input
.
BaseURL
,
s
.
cfg
.
Security
.
URLAllowlist
.
CRSHosts
,
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
@@ -197,6 +204,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
...
@@ -197,6 +204,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
Timeout
:
20
*
time
.
Second
,
Timeout
:
20
*
time
.
Second
,
ValidateResolvedIP
:
true
,
AllowPrivateHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
client
=
&
http
.
Client
{
Timeout
:
20
*
time
.
Second
}
...
@@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string {
...
@@ -1055,17 +1064,18 @@ func mapCRSStatus(isActive bool, status string) string {
return
"active"
return
"active"
}
}
func
normalizeBaseURL
(
raw
string
)
(
string
,
error
)
{
func
normalizeBaseURL
(
raw
string
,
allowlist
[]
string
,
allowPrivate
bool
)
(
string
,
error
)
{
trimmed
:=
strings
.
TrimSpace
(
raw
)
// 当 allowlist 为空时,不强制要求白名单(只进行基本的 URL 和 SSRF 验证)
if
trimmed
==
""
{
requireAllowlist
:=
len
(
allowlist
)
>
0
return
""
,
errors
.
New
(
"base_url is required"
)
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
}
AllowedHosts
:
allowlist
,
u
,
err
:=
url
.
Parse
(
trimmed
)
RequireAllowlist
:
requireAllowlist
,
if
err
!=
nil
||
u
.
Scheme
==
""
||
u
.
Host
==
""
{
AllowPrivate
:
allowPrivate
,
return
""
,
fmt
.
Errorf
(
"invalid base_url: %s"
,
trimmed
)
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
}
u
.
Path
=
strings
.
TrimRight
(
u
.
Path
,
"/"
)
return
normalized
,
nil
return
strings
.
TrimRight
(
u
.
String
(),
"/"
),
nil
}
}
// cleanBaseURL removes trailing suffix from base_url in credentials
// cleanBaseURL removes trailing suffix from base_url in credentials
...
...
backend/internal/service/gateway_service.go
View file @
46dda583
...
@@ -15,11 +15,14 @@ import (
...
@@ -15,11 +15,14 @@ import (
"regexp"
"regexp"
"sort"
"sort"
"strings"
"strings"
"sync/atomic"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tidwall/sjson"
...
@@ -30,6 +33,7 @@ const (
...
@@ -30,6 +33,7 @@ const (
claudeAPIURL
=
"https://api.anthropic.com/v1/messages?beta=true"
claudeAPIURL
=
"https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
10
*
1024
*
1024
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
)
)
...
@@ -1225,7 +1229,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
...
@@ -1225,7 +1229,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
targetURL
:=
claudeAPIURL
targetURL
:=
claudeAPIURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
baseURL
:=
account
.
GetBaseURL
()
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages"
if
baseURL
!=
""
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/v1/messages"
}
}
}
// OAuth账号:应用统一指纹
// OAuth账号:应用统一指纹
...
@@ -1594,12 +1604,87 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -1594,12 +1604,87 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var
firstTokenMs
*
int
var
firstTokenMs
*
int
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
scanner
:=
bufio
.
NewScanner
(
resp
.
Body
)
// 设置更大的buffer以处理长行
// 设置更大的buffer以处理长行
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
1024
*
1024
)
maxLineSize
:=
defaultMaxLineSize
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
MaxLineSize
>
0
{
maxLineSize
=
s
.
cfg
.
Gateway
.
MaxLineSize
}
scanner
.
Buffer
(
make
([]
byte
,
64
*
1024
),
maxLineSize
)
type
scanEvent
struct
{
line
string
err
error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events
:=
make
(
chan
scanEvent
,
16
)
done
:=
make
(
chan
struct
{})
sendEvent
:=
func
(
ev
scanEvent
)
bool
{
select
{
case
events
<-
ev
:
return
true
case
<-
done
:
return
false
}
}
var
lastReadAt
int64
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
go
func
()
{
defer
close
(
events
)
for
scanner
.
Scan
()
{
atomic
.
StoreInt64
(
&
lastReadAt
,
time
.
Now
()
.
UnixNano
())
if
!
sendEvent
(
scanEvent
{
line
:
scanner
.
Text
()})
{
return
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
_
=
sendEvent
(
scanEvent
{
err
:
err
})
}
}()
defer
close
(
done
)
streamInterval
:=
time
.
Duration
(
0
)
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
>
0
{
streamInterval
=
time
.
Duration
(
s
.
cfg
.
Gateway
.
StreamDataIntervalTimeout
)
*
time
.
Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var
intervalTicker
*
time
.
Ticker
if
streamInterval
>
0
{
intervalTicker
=
time
.
NewTicker
(
streamInterval
)
defer
intervalTicker
.
Stop
()
}
var
intervalCh
<-
chan
time
.
Time
if
intervalTicker
!=
nil
{
intervalCh
=
intervalTicker
.
C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent
:=
false
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
return
}
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
flusher
.
Flush
()
}
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
for
scanner
.
Scan
()
{
for
{
line
:=
scanner
.
Text
()
select
{
case
ev
,
ok
:=
<-
events
:
if
!
ok
{
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
ev
.
err
!=
nil
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
ev
.
err
}
sendErrorEvent
(
"stream_read_error"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
ev
.
err
)
}
line
:=
ev
.
line
if
line
==
"event: error"
{
if
line
==
"event: error"
{
return
nil
,
errors
.
New
(
"have error in stream"
)
return
nil
,
errors
.
New
(
"have error in stream"
)
}
}
...
@@ -1615,6 +1700,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -1615,6 +1700,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 转发行
// 转发行
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
}
flusher
.
Flush
()
flusher
.
Flush
()
...
@@ -1628,17 +1714,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
...
@@ -1628,17 +1714,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
else
{
}
else
{
// 非 data 行直接转发
// 非 data 行直接转发
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
}
}
flusher
.
Flush
()
flusher
.
Flush
()
}
}
}
if
err
:=
scanner
.
Err
();
err
!=
nil
{
case
<-
intervalCh
:
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream read error: %w"
,
err
)
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
sendErrorEvent
(
"stream_timeout"
)
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
}
}
return
&
streamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
}
// replaceModelInSSELine 替换SSE数据行中的model字段
// replaceModelInSSELine 替换SSE数据行中的model字段
...
@@ -1743,12 +1835,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
...
@@ -1743,12 +1835,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
body
=
s
.
replaceModelInResponseBody
(
body
,
mappedModel
,
originalModel
)
}
}
// 透传响应头
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
for
key
,
values
:=
range
resp
.
Header
{
for
_
,
value
:=
range
values
{
c
.
Header
(
key
,
value
)
}
}
// 写入响应
// 写入响应
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
c
.
Data
(
resp
.
StatusCode
,
"application/json"
,
body
)
...
@@ -2020,7 +2107,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
...
@@ -2020,7 +2107,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
targetURL
:=
claudeAPICountTokensURL
targetURL
:=
claudeAPICountTokensURL
if
account
.
Type
==
AccountTypeAPIKey
{
if
account
.
Type
==
AccountTypeAPIKey
{
baseURL
:=
account
.
GetBaseURL
()
baseURL
:=
account
.
GetBaseURL
()
targetURL
=
baseURL
+
"/v1/messages/count_tokens"
if
baseURL
!=
""
{
validatedURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
targetURL
=
validatedURL
+
"/v1/messages/count_tokens"
}
}
}
// OAuth 账号:应用统一指纹和重写 userID
// OAuth 账号:应用统一指纹和重写 userID
...
@@ -2100,6 +2193,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
...
@@ -2100,6 +2193,18 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
})
}
}
func
(
s
*
GatewayService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
// GetAvailableModels returns the list of models available for a group
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
// It aggregates model_mapping keys from all schedulable accounts in the group
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
func
(
s
*
GatewayService
)
GetAvailableModels
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
)
[]
string
{
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
46dda583
...
@@ -18,9 +18,12 @@ import (
...
@@ -18,9 +18,12 @@ import (
"strings"
"strings"
"time"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
)
)
...
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
...
@@ -41,6 +44,7 @@ type GeminiMessagesCompatService struct {
rateLimitService
*
RateLimitService
rateLimitService
*
RateLimitService
httpUpstream
HTTPUpstream
httpUpstream
HTTPUpstream
antigravityGatewayService
*
AntigravityGatewayService
antigravityGatewayService
*
AntigravityGatewayService
cfg
*
config
.
Config
}
}
func
NewGeminiMessagesCompatService
(
func
NewGeminiMessagesCompatService
(
...
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
...
@@ -51,6 +55,7 @@ func NewGeminiMessagesCompatService(
rateLimitService
*
RateLimitService
,
rateLimitService
*
RateLimitService
,
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
antigravityGatewayService
*
AntigravityGatewayService
,
antigravityGatewayService
*
AntigravityGatewayService
,
cfg
*
config
.
Config
,
)
*
GeminiMessagesCompatService
{
)
*
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
return
&
GeminiMessagesCompatService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
...
@@ -60,6 +65,7 @@ func NewGeminiMessagesCompatService(
rateLimitService
:
rateLimitService
,
rateLimitService
:
rateLimitService
,
httpUpstream
:
httpUpstream
,
httpUpstream
:
httpUpstream
,
antigravityGatewayService
:
antigravityGatewayService
,
antigravityGatewayService
:
antigravityGatewayService
,
cfg
:
cfg
,
}
}
}
}
...
@@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
...
@@ -230,6 +236,18 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return
s
.
antigravityGatewayService
return
s
.
antigravityGatewayService
}
}
func
(
s
*
GeminiMessagesCompatService
)
validateUpstreamBaseURL
(
raw
string
)
(
string
,
error
)
{
normalized
,
err
:=
urlvalidator
.
ValidateHTTPSURL
(
raw
,
urlvalidator
.
ValidationOptions
{
AllowedHosts
:
s
.
cfg
.
Security
.
URLAllowlist
.
UpstreamHosts
,
RequireAllowlist
:
true
,
AllowPrivate
:
s
.
cfg
.
Security
.
URLAllowlist
.
AllowPrivateHosts
,
})
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"invalid base_url: %w"
,
err
)
}
return
normalized
,
nil
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
func
(
s
*
GeminiMessagesCompatService
)
HasAntigravityAccounts
(
ctx
context
.
Context
,
groupID
*
int64
)
(
bool
,
error
)
{
var
accounts
[]
Account
var
accounts
[]
Account
...
@@ -381,16 +399,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -381,16 +399,20 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
action
:=
"generateContent"
action
:=
"generateContent"
if
req
.
Stream
{
if
req
.
Stream
{
action
=
"streamGenerateContent"
action
=
"streamGenerateContent"
}
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
mappedModel
,
action
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
mappedModel
,
action
)
if
req
.
Stream
{
if
req
.
Stream
{
fullURL
+=
"?alt=sse"
fullURL
+=
"?alt=sse"
}
}
...
@@ -427,7 +449,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -427,7 +449,11 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if
projectID
!=
""
{
if
projectID
!=
""
{
// Mode 1: Code Assist API
// Mode 1: Code Assist API
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
geminicli
.
GeminiCliBaseURL
,
action
)
baseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
action
)
if
useUpstreamStream
{
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
fullURL
+=
"?alt=sse"
}
}
...
@@ -453,12 +479,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -453,12 +479,16 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return
upstreamReq
,
"x-request-id"
,
nil
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
baseURL
,
mappedModel
,
action
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
,
mappedModel
,
action
)
if
useUpstreamStream
{
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
fullURL
+=
"?alt=sse"
}
}
...
@@ -650,12 +680,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -650,12 +680,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
return
nil
,
""
,
errors
.
New
(
"gemini api_key not configured"
)
}
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
b
aseURL
,
"/"
),
mappedModel
,
upstreamAction
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedB
aseURL
,
"/"
),
mappedModel
,
upstreamAction
)
if
useUpstreamStream
{
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
fullURL
+=
"?alt=sse"
}
}
...
@@ -687,7 +721,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -687,7 +721,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
// 2. Without project_id -> AI Studio API (direct OAuth, like API key but with Bearer token)
if
projectID
!=
""
&&
!
forceAIStudio
{
if
projectID
!=
""
&&
!
forceAIStudio
{
// Mode 1: Code Assist API
// Mode 1: Code Assist API
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
geminicli
.
GeminiCliBaseURL
,
upstreamAction
)
baseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
geminicli
.
GeminiCliBaseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
strings
.
TrimRight
(
baseURL
,
"/"
),
upstreamAction
)
if
useUpstreamStream
{
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
fullURL
+=
"?alt=sse"
}
}
...
@@ -713,12 +751,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -713,12 +751,16 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return
upstreamReq
,
"x-request-id"
,
nil
return
upstreamReq
,
"x-request-id"
,
nil
}
else
{
}
else
{
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
""
,
err
}
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
baseURL
,
mappedModel
,
upstreamAction
)
fullURL
:=
fmt
.
Sprintf
(
"%s/v1beta/models/%s:%s"
,
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
,
mappedModel
,
upstreamAction
)
if
useUpstreamStream
{
if
useUpstreamStream
{
fullURL
+=
"?alt=sse"
fullURL
+=
"?alt=sse"
}
}
...
@@ -1652,6 +1694,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
...
@@ -1652,6 +1694,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
_
=
json
.
Unmarshal
(
respBody
,
&
parsed
)
_
=
json
.
Unmarshal
(
respBody
,
&
parsed
)
}
}
responseheaders
.
WriteFilteredHeaders
(
c
.
Writer
.
Header
(),
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
if
contentType
==
""
{
contentType
=
"application/json"
contentType
=
"application/json"
...
@@ -1773,11 +1817,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
...
@@ -1773,11 +1817,15 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return
nil
,
errors
.
New
(
"invalid path"
)
return
nil
,
errors
.
New
(
"invalid path"
)
}
}
baseURL
:=
strings
.
Trim
Right
(
account
.
GetCredential
(
"base_url"
)
,
"/"
)
baseURL
:=
strings
.
Trim
Space
(
account
.
GetCredential
(
"base_url"
))
if
baseURL
==
""
{
if
baseURL
==
""
{
baseURL
=
geminicli
.
AIStudioBaseURL
baseURL
=
geminicli
.
AIStudioBaseURL
}
}
fullURL
:=
strings
.
TrimRight
(
baseURL
,
"/"
)
+
path
normalizedBaseURL
,
err
:=
s
.
validateUpstreamBaseURL
(
baseURL
)
if
err
!=
nil
{
return
nil
,
err
}
fullURL
:=
strings
.
TrimRight
(
normalizedBaseURL
,
"/"
)
+
path
var
proxyURL
string
var
proxyURL
string
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
if
account
.
ProxyID
!=
nil
&&
account
.
Proxy
!=
nil
{
...
@@ -1816,9 +1864,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
...
@@ -1816,9 +1864,14 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
8
<<
20
))
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
8
<<
20
))
wwwAuthenticate
:=
resp
.
Header
.
Get
(
"Www-Authenticate"
)
filteredHeaders
:=
responseheaders
.
FilterHeaders
(
resp
.
Header
,
s
.
cfg
.
Security
.
ResponseHeaders
)
if
wwwAuthenticate
!=
""
{
filteredHeaders
.
Set
(
"Www-Authenticate"
,
wwwAuthenticate
)
}
return
&
UpstreamHTTPResult
{
return
&
UpstreamHTTPResult
{
StatusCode
:
resp
.
StatusCode
,
StatusCode
:
resp
.
StatusCode
,
Headers
:
resp
.
Header
.
Clone
()
,
Headers
:
filteredHeaders
,
Body
:
body
,
Body
:
body
,
},
nil
},
nil
}
}
...
...
backend/internal/service/gemini_oauth_service.go
View file @
46dda583
...
@@ -1002,6 +1002,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
...
@@ -1002,6 +1002,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
client
,
err
:=
httpclient
.
GetClient
(
httpclient
.
Options
{
ProxyURL
:
strings
.
TrimSpace
(
proxyURL
),
ProxyURL
:
strings
.
TrimSpace
(
proxyURL
),
Timeout
:
30
*
time
.
Second
,
Timeout
:
30
*
time
.
Second
,
ValidateResolvedIP
:
true
,
})
})
if
err
!=
nil
{
if
err
!=
nil
{
client
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
client
=
&
http
.
Client
{
Timeout
:
30
*
time
.
Second
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment