Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
11bfc807
Commit
11bfc807
authored
Jan 09, 2026
by
song
Browse files
merge upstream/main
parents
c2a6ca8d
62dc0b95
Changes
101
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/auth_handler.go
View file @
11bfc807
...
...
@@ -15,14 +15,16 @@ type AuthHandler struct {
cfg
*
config
.
Config
authService
*
service
.
AuthService
userService
*
service
.
UserService
settingSvc
*
service
.
SettingService
}
// NewAuthHandler creates a new AuthHandler
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
)
*
AuthHandler
{
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
settingService
*
service
.
SettingService
)
*
AuthHandler
{
return
&
AuthHandler
{
cfg
:
cfg
,
authService
:
authService
,
userService
:
userService
,
settingSvc
:
settingService
,
}
}
...
...
backend/internal/handler/auth_linuxdo_oauth.go
0 → 100644
View file @
11bfc807
package
handler
import
(
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const
(
linuxDoOAuthCookiePath
=
"/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName
=
"linuxdo_oauth_state"
linuxDoOAuthVerifierCookie
=
"linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie
=
"linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec
=
10
*
60
// 10 minutes
linuxDoOAuthDefaultRedirectTo
=
"/dashboard"
linuxDoOAuthDefaultFrontendCB
=
"/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen
=
2048
linuxDoOAuthMaxFragmentValueLen
=
512
linuxDoOAuthMaxSubjectLen
=
64
-
len
(
"linuxdo-"
)
)
type
linuxDoTokenResponse
struct
{
AccessToken
string
`json:"access_token"`
TokenType
string
`json:"token_type"`
ExpiresIn
int64
`json:"expires_in"`
RefreshToken
string
`json:"refresh_token,omitempty"`
Scope
string
`json:"scope,omitempty"`
}
type
linuxDoTokenExchangeError
struct
{
StatusCode
int
ProviderError
string
ProviderDescription
string
Body
string
}
func
(
e
*
linuxDoTokenExchangeError
)
Error
()
string
{
if
e
==
nil
{
return
""
}
parts
:=
[]
string
{
fmt
.
Sprintf
(
"token exchange status=%d"
,
e
.
StatusCode
)}
if
strings
.
TrimSpace
(
e
.
ProviderError
)
!=
""
{
parts
=
append
(
parts
,
"error="
+
strings
.
TrimSpace
(
e
.
ProviderError
))
}
if
strings
.
TrimSpace
(
e
.
ProviderDescription
)
!=
""
{
parts
=
append
(
parts
,
"error_description="
+
strings
.
TrimSpace
(
e
.
ProviderDescription
))
}
return
strings
.
Join
(
parts
,
" "
)
}
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
func
(
h
*
AuthHandler
)
LinuxDoOAuthStart
(
c
*
gin
.
Context
)
{
cfg
,
err
:=
h
.
getLinuxDoOAuthConfig
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
state
,
err
:=
oauth
.
GenerateState
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_STATE_GEN_FAILED"
,
"failed to generate oauth state"
)
.
WithCause
(
err
))
return
}
redirectTo
:=
sanitizeFrontendRedirectPath
(
c
.
Query
(
"redirect"
))
if
redirectTo
==
""
{
redirectTo
=
linuxDoOAuthDefaultRedirectTo
}
secureCookie
:=
isRequestHTTPS
(
c
)
setCookie
(
c
,
linuxDoOAuthStateCookieName
,
encodeCookieValue
(
state
),
linuxDoOAuthCookieMaxAgeSec
,
secureCookie
)
setCookie
(
c
,
linuxDoOAuthRedirectCookie
,
encodeCookieValue
(
redirectTo
),
linuxDoOAuthCookieMaxAgeSec
,
secureCookie
)
codeChallenge
:=
""
if
cfg
.
UsePKCE
{
verifier
,
err
:=
oauth
.
GenerateCodeVerifier
()
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_PKCE_GEN_FAILED"
,
"failed to generate pkce verifier"
)
.
WithCause
(
err
))
return
}
codeChallenge
=
oauth
.
GenerateCodeChallenge
(
verifier
)
setCookie
(
c
,
linuxDoOAuthVerifierCookie
,
encodeCookieValue
(
verifier
),
linuxDoOAuthCookieMaxAgeSec
,
secureCookie
)
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_CONFIG_INVALID"
,
"oauth redirect url not configured"
))
return
}
authURL
,
err
:=
buildLinuxDoAuthorizeURL
(
cfg
,
state
,
codeChallenge
,
redirectURI
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"OAUTH_BUILD_URL_FAILED"
,
"failed to build oauth authorization url"
)
.
WithCause
(
err
))
return
}
c
.
Redirect
(
http
.
StatusFound
,
authURL
)
}
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
func
(
h
*
AuthHandler
)
LinuxDoOAuthCallback
(
c
*
gin
.
Context
)
{
cfg
,
cfgErr
:=
h
.
getLinuxDoOAuthConfig
(
c
.
Request
.
Context
())
if
cfgErr
!=
nil
{
response
.
ErrorFrom
(
c
,
cfgErr
)
return
}
frontendCallback
:=
strings
.
TrimSpace
(
cfg
.
FrontendRedirectURL
)
if
frontendCallback
==
""
{
frontendCallback
=
linuxDoOAuthDefaultFrontendCB
}
if
providerErr
:=
strings
.
TrimSpace
(
c
.
Query
(
"error"
));
providerErr
!=
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"provider_error"
,
providerErr
,
c
.
Query
(
"error_description"
))
return
}
code
:=
strings
.
TrimSpace
(
c
.
Query
(
"code"
))
state
:=
strings
.
TrimSpace
(
c
.
Query
(
"state"
))
if
code
==
""
||
state
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_params"
,
"missing code/state"
,
""
)
return
}
secureCookie
:=
isRequestHTTPS
(
c
)
defer
func
()
{
clearCookie
(
c
,
linuxDoOAuthStateCookieName
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthVerifierCookie
,
secureCookie
)
clearCookie
(
c
,
linuxDoOAuthRedirectCookie
,
secureCookie
)
}()
expectedState
,
err
:=
readCookieDecoded
(
c
,
linuxDoOAuthStateCookieName
)
if
err
!=
nil
||
expectedState
==
""
||
state
!=
expectedState
{
redirectOAuthError
(
c
,
frontendCallback
,
"invalid_state"
,
"invalid oauth state"
,
""
)
return
}
redirectTo
,
_
:=
readCookieDecoded
(
c
,
linuxDoOAuthRedirectCookie
)
redirectTo
=
sanitizeFrontendRedirectPath
(
redirectTo
)
if
redirectTo
==
""
{
redirectTo
=
linuxDoOAuthDefaultRedirectTo
}
codeVerifier
:=
""
if
cfg
.
UsePKCE
{
codeVerifier
,
_
=
readCookieDecoded
(
c
,
linuxDoOAuthVerifierCookie
)
if
codeVerifier
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"missing_verifier"
,
"missing pkce verifier"
,
""
)
return
}
}
redirectURI
:=
strings
.
TrimSpace
(
cfg
.
RedirectURL
)
if
redirectURI
==
""
{
redirectOAuthError
(
c
,
frontendCallback
,
"config_error"
,
"oauth redirect url not configured"
,
""
)
return
}
tokenResp
,
err
:=
linuxDoExchangeCode
(
c
.
Request
.
Context
(),
cfg
,
code
,
redirectURI
,
codeVerifier
)
if
err
!=
nil
{
description
:=
""
var
exchangeErr
*
linuxDoTokenExchangeError
if
errors
.
As
(
err
,
&
exchangeErr
)
&&
exchangeErr
!=
nil
{
log
.
Printf
(
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s"
,
exchangeErr
.
StatusCode
,
exchangeErr
.
ProviderError
,
exchangeErr
.
ProviderDescription
,
truncateLogValue
(
exchangeErr
.
Body
,
2048
),
)
description
=
exchangeErr
.
Error
()
}
else
{
log
.
Printf
(
"[LinuxDo OAuth] token exchange failed: %v"
,
err
)
description
=
err
.
Error
()
}
redirectOAuthError
(
c
,
frontendCallback
,
"token_exchange_failed"
,
"failed to exchange oauth code"
,
singleLine
(
description
))
return
}
email
,
username
,
subject
,
err
:=
linuxDoFetchUserInfo
(
c
.
Request
.
Context
(),
cfg
,
tokenResp
)
if
err
!=
nil
{
log
.
Printf
(
"[LinuxDo OAuth] userinfo fetch failed: %v"
,
err
)
redirectOAuthError
(
c
,
frontendCallback
,
"userinfo_failed"
,
"failed to fetch user info"
,
""
)
return
}
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if
subject
!=
""
{
email
=
linuxDoSyntheticEmail
(
subject
)
}
jwtToken
,
_
,
err
:=
h
.
authService
.
LoginOrRegisterOAuth
(
c
.
Request
.
Context
(),
email
,
username
)
if
err
!=
nil
{
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
fragment
:=
url
.
Values
{}
fragment
.
Set
(
"access_token"
,
jwtToken
)
fragment
.
Set
(
"token_type"
,
"Bearer"
)
fragment
.
Set
(
"redirect"
,
redirectTo
)
redirectWithFragment
(
c
,
frontendCallback
,
fragment
)
}
func
(
h
*
AuthHandler
)
getLinuxDoOAuthConfig
(
ctx
context
.
Context
)
(
config
.
LinuxDoConnectConfig
,
error
)
{
if
h
!=
nil
&&
h
.
settingSvc
!=
nil
{
return
h
.
settingSvc
.
GetLinuxDoConnectOAuthConfig
(
ctx
)
}
if
h
==
nil
||
h
.
cfg
==
nil
{
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
ServiceUnavailable
(
"CONFIG_NOT_READY"
,
"config not loaded"
)
}
if
!
h
.
cfg
.
LinuxDo
.
Enabled
{
return
config
.
LinuxDoConnectConfig
{},
infraerrors
.
NotFound
(
"OAUTH_DISABLED"
,
"oauth login is disabled"
)
}
return
h
.
cfg
.
LinuxDo
,
nil
}
func
linuxDoExchangeCode
(
ctx
context
.
Context
,
cfg
config
.
LinuxDoConnectConfig
,
code
string
,
redirectURI
string
,
codeVerifier
string
,
)
(
*
linuxDoTokenResponse
,
error
)
{
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
form
:=
url
.
Values
{}
form
.
Set
(
"grant_type"
,
"authorization_code"
)
form
.
Set
(
"client_id"
,
cfg
.
ClientID
)
form
.
Set
(
"code"
,
code
)
form
.
Set
(
"redirect_uri"
,
redirectURI
)
if
cfg
.
UsePKCE
{
form
.
Set
(
"code_verifier"
,
codeVerifier
)
}
r
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
cfg
.
TokenAuthMethod
))
{
case
""
,
"client_secret_post"
:
form
.
Set
(
"client_secret"
,
cfg
.
ClientSecret
)
case
"client_secret_basic"
:
r
.
SetBasicAuth
(
cfg
.
ClientID
,
cfg
.
ClientSecret
)
case
"none"
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported token_auth_method: %s"
,
cfg
.
TokenAuthMethod
)
}
resp
,
err
:=
r
.
SetFormDataFromValues
(
form
)
.
Post
(
cfg
.
TokenURL
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"request token: %w"
,
err
)
}
body
:=
strings
.
TrimSpace
(
resp
.
String
())
if
!
resp
.
IsSuccessState
()
{
providerErr
,
providerDesc
:=
parseOAuthProviderError
(
body
)
return
nil
,
&
linuxDoTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
ProviderError
:
providerErr
,
ProviderDescription
:
providerDesc
,
Body
:
body
,
}
}
tokenResp
,
ok
:=
parseLinuxDoTokenResponse
(
body
)
if
!
ok
||
strings
.
TrimSpace
(
tokenResp
.
AccessToken
)
==
""
{
return
nil
,
&
linuxDoTokenExchangeError
{
StatusCode
:
resp
.
StatusCode
,
Body
:
body
,
}
}
if
strings
.
TrimSpace
(
tokenResp
.
TokenType
)
==
""
{
tokenResp
.
TokenType
=
"Bearer"
}
return
tokenResp
,
nil
}
func
linuxDoFetchUserInfo
(
ctx
context
.
Context
,
cfg
config
.
LinuxDoConnectConfig
,
token
*
linuxDoTokenResponse
,
)
(
email
string
,
username
string
,
subject
string
,
err
error
)
{
client
:=
req
.
C
()
.
SetTimeout
(
30
*
time
.
Second
)
authorization
,
err
:=
buildBearerAuthorization
(
token
.
TokenType
,
token
.
AccessToken
)
if
err
!=
nil
{
return
""
,
""
,
""
,
fmt
.
Errorf
(
"invalid token for userinfo request: %w"
,
err
)
}
resp
,
err
:=
client
.
R
()
.
SetContext
(
ctx
)
.
SetHeader
(
"Accept"
,
"application/json"
)
.
SetHeader
(
"Authorization"
,
authorization
)
.
Get
(
cfg
.
UserInfoURL
)
if
err
!=
nil
{
return
""
,
""
,
""
,
fmt
.
Errorf
(
"request userinfo: %w"
,
err
)
}
if
!
resp
.
IsSuccessState
()
{
return
""
,
""
,
""
,
fmt
.
Errorf
(
"userinfo status=%d"
,
resp
.
StatusCode
)
}
return
linuxDoParseUserInfo
(
resp
.
String
(),
cfg
)
}
func
linuxDoParseUserInfo
(
body
string
,
cfg
config
.
LinuxDoConnectConfig
)
(
email
string
,
username
string
,
subject
string
,
err
error
)
{
email
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoEmailPath
),
getGJSON
(
body
,
"email"
),
getGJSON
(
body
,
"user.email"
),
getGJSON
(
body
,
"data.email"
),
getGJSON
(
body
,
"attributes.email"
),
)
username
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoUsernamePath
),
getGJSON
(
body
,
"username"
),
getGJSON
(
body
,
"preferred_username"
),
getGJSON
(
body
,
"name"
),
getGJSON
(
body
,
"user.username"
),
getGJSON
(
body
,
"user.name"
),
)
subject
=
firstNonEmpty
(
getGJSON
(
body
,
cfg
.
UserInfoIDPath
),
getGJSON
(
body
,
"sub"
),
getGJSON
(
body
,
"id"
),
getGJSON
(
body
,
"user_id"
),
getGJSON
(
body
,
"uid"
),
getGJSON
(
body
,
"user.id"
),
)
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
{
return
""
,
""
,
""
,
errors
.
New
(
"userinfo missing id field"
)
}
if
!
isSafeLinuxDoSubject
(
subject
)
{
return
""
,
""
,
""
,
errors
.
New
(
"userinfo returned invalid id field"
)
}
email
=
strings
.
TrimSpace
(
email
)
if
email
==
""
{
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
email
=
linuxDoSyntheticEmail
(
subject
)
}
username
=
strings
.
TrimSpace
(
username
)
if
username
==
""
{
username
=
"linuxdo_"
+
subject
}
return
email
,
username
,
subject
,
nil
}
func
buildLinuxDoAuthorizeURL
(
cfg
config
.
LinuxDoConnectConfig
,
state
string
,
codeChallenge
string
,
redirectURI
string
)
(
string
,
error
)
{
u
,
err
:=
url
.
Parse
(
cfg
.
AuthorizeURL
)
if
err
!=
nil
{
return
""
,
fmt
.
Errorf
(
"parse authorize_url: %w"
,
err
)
}
q
:=
u
.
Query
()
q
.
Set
(
"response_type"
,
"code"
)
q
.
Set
(
"client_id"
,
cfg
.
ClientID
)
q
.
Set
(
"redirect_uri"
,
redirectURI
)
if
strings
.
TrimSpace
(
cfg
.
Scopes
)
!=
""
{
q
.
Set
(
"scope"
,
cfg
.
Scopes
)
}
q
.
Set
(
"state"
,
state
)
if
cfg
.
UsePKCE
{
q
.
Set
(
"code_challenge"
,
codeChallenge
)
q
.
Set
(
"code_challenge_method"
,
"S256"
)
}
u
.
RawQuery
=
q
.
Encode
()
return
u
.
String
(),
nil
}
func
redirectOAuthError
(
c
*
gin
.
Context
,
frontendCallback
string
,
code
string
,
message
string
,
description
string
)
{
fragment
:=
url
.
Values
{}
fragment
.
Set
(
"error"
,
truncateFragmentValue
(
code
))
if
strings
.
TrimSpace
(
message
)
!=
""
{
fragment
.
Set
(
"error_message"
,
truncateFragmentValue
(
message
))
}
if
strings
.
TrimSpace
(
description
)
!=
""
{
fragment
.
Set
(
"error_description"
,
truncateFragmentValue
(
description
))
}
redirectWithFragment
(
c
,
frontendCallback
,
fragment
)
}
func
redirectWithFragment
(
c
*
gin
.
Context
,
frontendCallback
string
,
fragment
url
.
Values
)
{
u
,
err
:=
url
.
Parse
(
frontendCallback
)
if
err
!=
nil
{
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
c
.
Redirect
(
http
.
StatusFound
,
linuxDoOAuthDefaultRedirectTo
)
return
}
if
u
.
Scheme
!=
""
&&
!
strings
.
EqualFold
(
u
.
Scheme
,
"http"
)
&&
!
strings
.
EqualFold
(
u
.
Scheme
,
"https"
)
{
c
.
Redirect
(
http
.
StatusFound
,
linuxDoOAuthDefaultRedirectTo
)
return
}
u
.
Fragment
=
fragment
.
Encode
()
c
.
Header
(
"Cache-Control"
,
"no-store"
)
c
.
Header
(
"Pragma"
,
"no-cache"
)
c
.
Redirect
(
http
.
StatusFound
,
u
.
String
())
}
func
firstNonEmpty
(
values
...
string
)
string
{
for
_
,
v
:=
range
values
{
v
=
strings
.
TrimSpace
(
v
)
if
v
!=
""
{
return
v
}
}
return
""
}
func
parseOAuthProviderError
(
body
string
)
(
providerErr
string
,
providerDesc
string
)
{
body
=
strings
.
TrimSpace
(
body
)
if
body
==
""
{
return
""
,
""
}
providerErr
=
firstNonEmpty
(
getGJSON
(
body
,
"error"
),
getGJSON
(
body
,
"code"
),
getGJSON
(
body
,
"error.code"
),
)
providerDesc
=
firstNonEmpty
(
getGJSON
(
body
,
"error_description"
),
getGJSON
(
body
,
"error.message"
),
getGJSON
(
body
,
"message"
),
getGJSON
(
body
,
"detail"
),
)
if
providerErr
!=
""
||
providerDesc
!=
""
{
return
providerErr
,
providerDesc
}
values
,
err
:=
url
.
ParseQuery
(
body
)
if
err
!=
nil
{
return
""
,
""
}
providerErr
=
firstNonEmpty
(
values
.
Get
(
"error"
),
values
.
Get
(
"code"
))
providerDesc
=
firstNonEmpty
(
values
.
Get
(
"error_description"
),
values
.
Get
(
"error_message"
),
values
.
Get
(
"message"
))
return
providerErr
,
providerDesc
}
func
parseLinuxDoTokenResponse
(
body
string
)
(
*
linuxDoTokenResponse
,
bool
)
{
body
=
strings
.
TrimSpace
(
body
)
if
body
==
""
{
return
nil
,
false
}
accessToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"access_token"
))
if
accessToken
!=
""
{
tokenType
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"token_type"
))
refreshToken
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"refresh_token"
))
scope
:=
strings
.
TrimSpace
(
getGJSON
(
body
,
"scope"
))
expiresIn
:=
gjson
.
Get
(
body
,
"expires_in"
)
.
Int
()
return
&
linuxDoTokenResponse
{
AccessToken
:
accessToken
,
TokenType
:
tokenType
,
ExpiresIn
:
expiresIn
,
RefreshToken
:
refreshToken
,
Scope
:
scope
,
},
true
}
values
,
err
:=
url
.
ParseQuery
(
body
)
if
err
!=
nil
{
return
nil
,
false
}
accessToken
=
strings
.
TrimSpace
(
values
.
Get
(
"access_token"
))
if
accessToken
==
""
{
return
nil
,
false
}
expiresIn
:=
int64
(
0
)
if
raw
:=
strings
.
TrimSpace
(
values
.
Get
(
"expires_in"
));
raw
!=
""
{
if
v
,
err
:=
strconv
.
ParseInt
(
raw
,
10
,
64
);
err
==
nil
{
expiresIn
=
v
}
}
return
&
linuxDoTokenResponse
{
AccessToken
:
accessToken
,
TokenType
:
strings
.
TrimSpace
(
values
.
Get
(
"token_type"
)),
ExpiresIn
:
expiresIn
,
RefreshToken
:
strings
.
TrimSpace
(
values
.
Get
(
"refresh_token"
)),
Scope
:
strings
.
TrimSpace
(
values
.
Get
(
"scope"
)),
},
true
}
func
getGJSON
(
body
string
,
path
string
)
string
{
path
=
strings
.
TrimSpace
(
path
)
if
path
==
""
{
return
""
}
res
:=
gjson
.
Get
(
body
,
path
)
if
!
res
.
Exists
()
{
return
""
}
return
res
.
String
()
}
func
truncateLogValue
(
value
string
,
maxLen
int
)
string
{
value
=
strings
.
TrimSpace
(
value
)
if
value
==
""
||
maxLen
<=
0
{
return
""
}
if
len
(
value
)
<=
maxLen
{
return
value
}
value
=
value
[
:
maxLen
]
for
!
utf8
.
ValidString
(
value
)
{
value
=
value
[
:
len
(
value
)
-
1
]
}
return
value
}
func
singleLine
(
value
string
)
string
{
value
=
strings
.
TrimSpace
(
value
)
if
value
==
""
{
return
""
}
return
strings
.
Join
(
strings
.
Fields
(
value
),
" "
)
}
func
sanitizeFrontendRedirectPath
(
path
string
)
string
{
path
=
strings
.
TrimSpace
(
path
)
if
path
==
""
{
return
""
}
if
len
(
path
)
>
linuxDoOAuthMaxRedirectLen
{
return
""
}
// 只允许同源相对路径(避免开放重定向)。
if
!
strings
.
HasPrefix
(
path
,
"/"
)
{
return
""
}
if
strings
.
HasPrefix
(
path
,
"//"
)
{
return
""
}
if
strings
.
Contains
(
path
,
"://"
)
{
return
""
}
if
strings
.
ContainsAny
(
path
,
"
\r\n
"
)
{
return
""
}
return
path
}
func
isRequestHTTPS
(
c
*
gin
.
Context
)
bool
{
if
c
.
Request
.
TLS
!=
nil
{
return
true
}
proto
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
GetHeader
(
"X-Forwarded-Proto"
)))
return
proto
==
"https"
}
func
encodeCookieValue
(
value
string
)
string
{
return
base64
.
RawURLEncoding
.
EncodeToString
([]
byte
(
value
))
}
func
decodeCookieValue
(
value
string
)
(
string
,
error
)
{
raw
,
err
:=
base64
.
RawURLEncoding
.
DecodeString
(
value
)
if
err
!=
nil
{
return
""
,
err
}
return
string
(
raw
),
nil
}
func
readCookieDecoded
(
c
*
gin
.
Context
,
name
string
)
(
string
,
error
)
{
ck
,
err
:=
c
.
Request
.
Cookie
(
name
)
if
err
!=
nil
{
return
""
,
err
}
return
decodeCookieValue
(
ck
.
Value
)
}
func
setCookie
(
c
*
gin
.
Context
,
name
string
,
value
string
,
maxAgeSec
int
,
secure
bool
)
{
http
.
SetCookie
(
c
.
Writer
,
&
http
.
Cookie
{
Name
:
name
,
Value
:
value
,
Path
:
linuxDoOAuthCookiePath
,
MaxAge
:
maxAgeSec
,
HttpOnly
:
true
,
Secure
:
secure
,
SameSite
:
http
.
SameSiteLaxMode
,
})
}
func
clearCookie
(
c
*
gin
.
Context
,
name
string
,
secure
bool
)
{
http
.
SetCookie
(
c
.
Writer
,
&
http
.
Cookie
{
Name
:
name
,
Value
:
""
,
Path
:
linuxDoOAuthCookiePath
,
MaxAge
:
-
1
,
HttpOnly
:
true
,
Secure
:
secure
,
SameSite
:
http
.
SameSiteLaxMode
,
})
}
func
truncateFragmentValue
(
value
string
)
string
{
value
=
strings
.
TrimSpace
(
value
)
if
value
==
""
{
return
""
}
if
len
(
value
)
>
linuxDoOAuthMaxFragmentValueLen
{
value
=
value
[
:
linuxDoOAuthMaxFragmentValueLen
]
for
!
utf8
.
ValidString
(
value
)
{
value
=
value
[
:
len
(
value
)
-
1
]
}
}
return
value
}
func
buildBearerAuthorization
(
tokenType
,
accessToken
string
)
(
string
,
error
)
{
tokenType
=
strings
.
TrimSpace
(
tokenType
)
if
tokenType
==
""
{
tokenType
=
"Bearer"
}
if
!
strings
.
EqualFold
(
tokenType
,
"Bearer"
)
{
return
""
,
fmt
.
Errorf
(
"unsupported token_type: %s"
,
tokenType
)
}
accessToken
=
strings
.
TrimSpace
(
accessToken
)
if
accessToken
==
""
{
return
""
,
errors
.
New
(
"missing access_token"
)
}
if
strings
.
ContainsAny
(
accessToken
,
"
\t\r\n
"
)
{
return
""
,
errors
.
New
(
"access_token contains whitespace"
)
}
return
"Bearer "
+
accessToken
,
nil
}
func
isSafeLinuxDoSubject
(
subject
string
)
bool
{
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
||
len
(
subject
)
>
linuxDoOAuthMaxSubjectLen
{
return
false
}
for
_
,
r
:=
range
subject
{
switch
{
case
r
>=
'0'
&&
r
<=
'9'
:
case
r
>=
'a'
&&
r
<=
'z'
:
case
r
>=
'A'
&&
r
<=
'Z'
:
case
r
==
'_'
||
r
==
'-'
:
default
:
return
false
}
}
return
true
}
func
linuxDoSyntheticEmail
(
subject
string
)
string
{
subject
=
strings
.
TrimSpace
(
subject
)
if
subject
==
""
{
return
""
}
return
"linuxdo-"
+
subject
+
service
.
LinuxDoConnectSyntheticEmailDomain
}
backend/internal/handler/auth_linuxdo_oauth_test.go
0 → 100644
View file @
11bfc807
package
handler
import
(
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func
TestSanitizeFrontendRedirectPath
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"/dashboard"
,
sanitizeFrontendRedirectPath
(
"/dashboard"
))
require
.
Equal
(
t
,
"/dashboard"
,
sanitizeFrontendRedirectPath
(
" /dashboard "
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"dashboard"
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"//evil.com"
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"https://evil.com"
))
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
"/
\n
foo"
))
long
:=
"/"
+
strings
.
Repeat
(
"a"
,
linuxDoOAuthMaxRedirectLen
)
require
.
Equal
(
t
,
""
,
sanitizeFrontendRedirectPath
(
long
))
}
func
TestBuildBearerAuthorization
(
t
*
testing
.
T
)
{
auth
,
err
:=
buildBearerAuthorization
(
""
,
"token123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"Bearer token123"
,
auth
)
auth
,
err
=
buildBearerAuthorization
(
"bearer"
,
"token123"
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"Bearer token123"
,
auth
)
_
,
err
=
buildBearerAuthorization
(
"MAC"
,
"token123"
)
require
.
Error
(
t
,
err
)
_
,
err
=
buildBearerAuthorization
(
"Bearer"
,
"token 123"
)
require
.
Error
(
t
,
err
)
}
func
TestLinuxDoParseUserInfoParsesIDAndUsername
(
t
*
testing
.
T
)
{
cfg
:=
config
.
LinuxDoConnectConfig
{
UserInfoURL
:
"https://connect.linux.do/api/user"
,
}
email
,
username
,
subject
,
err
:=
linuxDoParseUserInfo
(
`{"id":123,"username":"alice"}`
,
cfg
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"123"
,
subject
)
require
.
Equal
(
t
,
"alice"
,
username
)
require
.
Equal
(
t
,
"linuxdo-123@linuxdo-connect.invalid"
,
email
)
}
func
TestLinuxDoParseUserInfoDefaultsUsername
(
t
*
testing
.
T
)
{
cfg
:=
config
.
LinuxDoConnectConfig
{
UserInfoURL
:
"https://connect.linux.do/api/user"
,
}
email
,
username
,
subject
,
err
:=
linuxDoParseUserInfo
(
`{"id":"123"}`
,
cfg
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"123"
,
subject
)
require
.
Equal
(
t
,
"linuxdo_123"
,
username
)
require
.
Equal
(
t
,
"linuxdo-123@linuxdo-connect.invalid"
,
email
)
}
func
TestLinuxDoParseUserInfoRejectsUnsafeSubject
(
t
*
testing
.
T
)
{
cfg
:=
config
.
LinuxDoConnectConfig
{
UserInfoURL
:
"https://connect.linux.do/api/user"
,
}
_
,
_
,
_
,
err
:=
linuxDoParseUserInfo
(
`{"id":"123@456"}`
,
cfg
)
require
.
Error
(
t
,
err
)
tooLong
:=
strings
.
Repeat
(
"a"
,
linuxDoOAuthMaxSubjectLen
+
1
)
_
,
_
,
_
,
err
=
linuxDoParseUserInfo
(
`{"id":"`
+
tooLong
+
`"}`
,
cfg
)
require
.
Error
(
t
,
err
)
}
func
TestParseOAuthProviderErrorJSON
(
t
*
testing
.
T
)
{
code
,
desc
:=
parseOAuthProviderError
(
`{"error":"invalid_client","error_description":"bad secret"}`
)
require
.
Equal
(
t
,
"invalid_client"
,
code
)
require
.
Equal
(
t
,
"bad secret"
,
desc
)
}
func
TestParseOAuthProviderErrorForm
(
t
*
testing
.
T
)
{
code
,
desc
:=
parseOAuthProviderError
(
"error=invalid_request&error_description=Missing+code_verifier"
)
require
.
Equal
(
t
,
"invalid_request"
,
code
)
require
.
Equal
(
t
,
"Missing code_verifier"
,
desc
)
}
func
TestParseLinuxDoTokenResponseJSON
(
t
*
testing
.
T
)
{
token
,
ok
:=
parseLinuxDoTokenResponse
(
`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"t1"
,
token
.
AccessToken
)
require
.
Equal
(
t
,
"Bearer"
,
token
.
TokenType
)
require
.
Equal
(
t
,
int64
(
3600
),
token
.
ExpiresIn
)
require
.
Equal
(
t
,
"user"
,
token
.
Scope
)
}
func
TestParseLinuxDoTokenResponseForm
(
t
*
testing
.
T
)
{
token
,
ok
:=
parseLinuxDoTokenResponse
(
"access_token=t2&token_type=bearer&expires_in=60"
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"t2"
,
token
.
AccessToken
)
require
.
Equal
(
t
,
"bearer"
,
token
.
TokenType
)
require
.
Equal
(
t
,
int64
(
60
),
token
.
ExpiresIn
)
}
func
TestSingleLineStripsWhitespace
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"hello world"
,
singleLine
(
"hello
\r\n
world"
))
require
.
Equal
(
t
,
""
,
singleLine
(
"
\n\t\r
"
))
}
backend/internal/handler/dto/mappers.go
View file @
11bfc807
...
...
@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice1K
:
g
.
ImagePrice1K
,
ImagePrice2K
:
g
.
ImagePrice2K
,
ImagePrice4K
:
g
.
ImagePrice4K
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
AccountCount
:
g
.
AccountCount
,
...
...
@@ -280,6 +282,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
FirstTokenMs
:
l
.
FirstTokenMs
,
ImageCount
:
l
.
ImageCount
,
ImageSize
:
l
.
ImageSize
,
UserAgent
:
l
.
UserAgent
,
CreatedAt
:
l
.
CreatedAt
,
User
:
UserFromServiceShallow
(
l
.
User
),
APIKey
:
APIKeyFromService
(
l
.
APIKey
),
...
...
backend/internal/handler/dto/settings.go
View file @
11bfc807
...
...
@@ -17,6 +17,11 @@ type SystemSettings struct {
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKeyConfigured
bool
`json:"turnstile_secret_key_configured"`
LinuxDoConnectEnabled
bool
`json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID
string
`json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecretConfigured
bool
`json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL
string
`json:"linuxdo_connect_redirect_url"`
SiteName
string
`json:"site_name"`
SiteLogo
string
`json:"site_logo"`
SiteSubtitle
string
`json:"site_subtitle"`
...
...
@@ -50,5 +55,6 @@ type PublicSettings struct {
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version"`
}
backend/internal/handler/dto/types.go
View file @
11bfc807
...
...
@@ -52,6 +52,10 @@ type Group struct {
ImagePrice2K
*
float64
`json:"image_price_2k"`
ImagePrice4K
*
float64
`json:"image_price_4k"`
// Claude Code 客户端限制
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -180,6 +184,9 @@ type UsageLog struct {
ImageCount
int
`json:"image_count"`
ImageSize
*
string
`json:"image_size"`
// User-Agent
UserAgent
*
string
`json:"user_agent"`
CreatedAt
time
.
Time
`json:"created_at"`
User
*
User
`json:"user,omitempty"`
...
...
backend/internal/handler/gateway_handler.go
View file @
11bfc807
...
...
@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel
:=
parsedReq
.
Model
reqStream
:=
parsedReq
.
Stream
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext
(
c
,
body
)
// 验证 model 必填
if
reqModel
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
...
...
@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
...
...
@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
...
...
@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext
(
c
,
body
)
// 验证 model 必填
if
parsedReq
.
Model
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
...
...
backend/internal/handler/gateway_helper.go
View file @
11bfc807
...
...
@@ -2,6 +2,7 @@ package handler
import
(
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
...
...
@@ -13,6 +14,26 @@ import (
"github.com/gin-gonic/gin"
)
// claudeCodeValidator is a singleton validator for Claude Code client detection
var
claudeCodeValidator
=
service
.
NewClaudeCodeValidator
()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func
SetClaudeCodeClientContext
(
c
*
gin
.
Context
,
body
[]
byte
)
{
// 解析请求体为 map
var
bodyMap
map
[
string
]
any
if
len
(
body
)
>
0
{
_
=
json
.
Unmarshal
(
body
,
&
bodyMap
)
}
// 验证是否为 Claude Code 客户端
isClaudeCode
:=
claudeCodeValidator
.
Validate
(
c
.
Request
,
bodyMap
)
// 更新 request context
ctx
:=
service
.
SetClaudeCodeClient
(
c
.
Request
.
Context
(),
isClaudeCode
)
c
.
Request
=
c
.
Request
.
WithContext
(
ctx
)
}
// 并发槽位等待相关常量
//
// 性能优化说明:
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
11bfc807
...
...
@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext
(
c
,
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionKey
:=
sessionHash
if
sessionHash
!=
""
{
...
...
@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
11bfc807
...
...
@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionHash
,
account
.
ID
);
err
!=
nil
{
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
...
...
backend/internal/handler/setting_handler.go
View file @
11bfc807
...
...
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL
:
settings
.
APIBaseURL
,
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
h
.
version
,
})
}
backend/internal/pkg/antigravity/client.go
View file @
11bfc807
...
...
@@ -5,8 +5,11 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
...
...
@@ -22,10 +25,10 @@ func resolveHost(urlStr string) string {
return
parsed
.
Host
}
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点)
func
NewAPIRequest
(
ctx
context
.
Context
,
action
,
accessToken
string
,
body
[]
byte
)
(
*
http
.
Request
,
error
)
{
// NewAPIRequest
WithURL 使用指定的 base URL
创建 Antigravity API 请求(v1internal 端点)
func
NewAPIRequest
WithURL
(
ctx
context
.
Context
,
baseURL
,
action
,
accessToken
string
,
body
[]
byte
)
(
*
http
.
Request
,
error
)
{
// 构建 URL,流式请求添加 ?alt=sse 参数
apiURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
B
aseURL
,
action
)
apiURL
:=
fmt
.
Sprintf
(
"%s/v1internal:%s"
,
b
aseURL
,
action
)
isStream
:=
action
==
"streamGenerateContent"
if
isStream
{
apiURL
+=
"?alt=sse"
...
...
@@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte)
req
.
Host
=
host
}
// 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header
return
req
,
nil
}
// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点)
// 向后兼容:仅使用默认 BaseURL
func
NewAPIRequest
(
ctx
context
.
Context
,
action
,
accessToken
string
,
body
[]
byte
)
(
*
http
.
Request
,
error
)
{
return
NewAPIRequestWithURL
(
ctx
,
BaseURL
,
action
,
accessToken
,
body
)
}
// TokenResponse Google OAuth token 响应
type
TokenResponse
struct
{
AccessToken
string
`json:"access_token"`
...
...
@@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client {
}
}
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func
isConnectionError
(
err
error
)
bool
{
if
err
==
nil
{
return
false
}
// 检查超时错误
var
netErr
net
.
Error
if
errors
.
As
(
err
,
&
netErr
)
&&
netErr
.
Timeout
()
{
return
true
}
// 检查连接错误(DNS 失败、连接拒绝)
var
opErr
*
net
.
OpError
if
errors
.
As
(
err
,
&
opErr
)
{
return
true
}
// 检查 URL 错误
var
urlErr
*
url
.
Error
return
errors
.
As
(
err
,
&
urlErr
)
}
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
func
shouldFallbackToNextURL
(
err
error
,
statusCode
int
)
bool
{
if
isConnectionError
(
err
)
{
return
true
}
return
statusCode
==
http
.
StatusTooManyRequests
}
// ExchangeCode 用 authorization code 交换 token
func
(
c
*
Client
)
ExchangeCode
(
ctx
context
.
Context
,
code
,
codeVerifier
string
)
(
*
TokenResponse
,
error
)
{
params
:=
url
.
Values
{}
...
...
@@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
}
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
// 支持 URL fallback:sandbox → daily → prod
func
(
c
*
Client
)
LoadCodeAssist
(
ctx
context
.
Context
,
accessToken
string
)
(
*
LoadCodeAssistResponse
,
map
[
string
]
any
,
error
)
{
reqBody
:=
LoadCodeAssistRequest
{}
reqBody
.
Metadata
.
IDEType
=
"ANTIGRAVITY"
...
...
@@ -281,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return
nil
,
nil
,
fmt
.
Errorf
(
"序列化请求失败: %w"
,
err
)
}
url
:=
BaseURL
+
"/v1internal:loadCodeAssist"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
url
,
strings
.
NewReader
(
string
(
bodyBytes
))
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
// 获取可用的 URL 列表
availableURLs
:=
DefaultURLAvailability
.
GetAvailableURLs
(
)
if
len
(
availableURLs
)
==
0
{
availableURLs
=
BaseURLs
// 所有 URL 都不可用时,重试所有
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"User-Agent"
,
UserAgent
)
resp
,
err
:=
c
.
httpClient
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"loadCodeAssist 请求失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
var
lastErr
error
for
urlIdx
,
baseURL
:=
range
availableURLs
{
apiURL
:=
baseURL
+
"/v1internal:loadCodeAssist"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
apiURL
,
strings
.
NewReader
(
string
(
bodyBytes
)))
if
err
!=
nil
{
lastErr
=
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
continue
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"User-Agent"
,
UserAgent
)
resp
,
err
:=
c
.
httpClient
.
Do
(
req
)
if
err
!=
nil
{
lastErr
=
fmt
.
Errorf
(
"loadCodeAssist 请求失败: %w"
,
err
)
if
shouldFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] loadCodeAssist URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
}
return
nil
,
nil
,
lastErr
}
respBodyBytes
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
respBodyBytes
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
_
=
resp
.
Body
.
Close
()
// 立即关闭,避免循环内 defer 导致的资源泄漏
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
nil
,
fmt
.
Errorf
(
"loadCodeAssist 失败 (HTTP %d): %s"
,
resp
.
StatusCode
,
string
(
respBodyBytes
))
}
// 检查是否需要 URL 降级
if
shouldFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
}
var
loadResp
LoadCodeAssistResponse
if
err
:=
json
.
Unmarshal
(
respBodyBytes
,
&
loadResp
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"响应解析失败: %w"
,
err
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
nil
,
fmt
.
Errorf
(
"loadCodeAssist 失败 (HTTP %d): %s"
,
resp
.
StatusCode
,
string
(
respBodyBytes
))
}
// 解析原始 JSON 为 map
var
rawResp
map
[
string
]
any
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
var
loadResp
LoadCodeAssistResponse
if
err
:=
json
.
Unmarshal
(
respBodyBytes
,
&
loadResp
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"响应解析失败: %w"
,
err
)
}
// 解析原始 JSON 为 map
var
rawResp
map
[
string
]
any
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
return
&
loadResp
,
rawResp
,
nil
}
return
&
loadResp
,
rawResp
,
nil
return
nil
,
nil
,
lastErr
}
// ModelQuotaInfo 模型配额信息
...
...
@@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct {
}
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
// 支持 URL fallback:sandbox → daily → prod
func
(
c
*
Client
)
FetchAvailableModels
(
ctx
context
.
Context
,
accessToken
,
projectID
string
)
(
*
FetchAvailableModelsResponse
,
map
[
string
]
any
,
error
)
{
reqBody
:=
FetchAvailableModelsRequest
{
Project
:
projectID
}
bodyBytes
,
err
:=
json
.
Marshal
(
reqBody
)
...
...
@@ -346,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return
nil
,
nil
,
fmt
.
Errorf
(
"序列化请求失败: %w"
,
err
)
}
apiURL
:=
BaseURL
+
"/v1internal:fetchAvailableModels"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
apiURL
,
strings
.
NewReader
(
string
(
bodyBytes
))
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
// 获取可用的 URL 列表
availableURLs
:=
DefaultURLAvailability
.
GetAvailableURLs
(
)
if
len
(
availableURLs
)
==
0
{
availableURLs
=
BaseURLs
// 所有 URL 都不可用时,重试所有
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"User-Agent"
,
UserAgent
)
resp
,
err
:=
c
.
httpClient
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fetchAvailableModels 请求失败: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
var
lastErr
error
for
urlIdx
,
baseURL
:=
range
availableURLs
{
apiURL
:=
baseURL
+
"/v1internal:fetchAvailableModels"
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
apiURL
,
strings
.
NewReader
(
string
(
bodyBytes
)))
if
err
!=
nil
{
lastErr
=
fmt
.
Errorf
(
"创建请求失败: %w"
,
err
)
continue
}
req
.
Header
.
Set
(
"Authorization"
,
"Bearer "
+
accessToken
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"User-Agent"
,
UserAgent
)
resp
,
err
:=
c
.
httpClient
.
Do
(
req
)
if
err
!=
nil
{
lastErr
=
fmt
.
Errorf
(
"fetchAvailableModels 请求失败: %w"
,
err
)
if
shouldFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] fetchAvailableModels URL fallback: %s -> %s"
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
}
return
nil
,
nil
,
lastErr
}
respBodyBytes
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
respBodyBytes
,
err
:=
io
.
ReadAll
(
resp
.
Body
)
_
=
resp
.
Body
.
Close
()
// 立即关闭,避免循环内 defer 导致的资源泄漏
if
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"读取响应失败: %w"
,
err
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fetchAvailableModels 失败 (HTTP %d): %s"
,
resp
.
StatusCode
,
string
(
respBodyBytes
))
}
// 检查是否需要 URL 降级
if
shouldFallbackToNextURL
(
nil
,
resp
.
StatusCode
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s"
,
resp
.
StatusCode
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
continue
}
var
modelsResp
FetchAvailableModelsResponse
if
err
:=
json
.
Unmarshal
(
respBodyBytes
,
&
modelsResp
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"响应解析失败: %w"
,
err
)
}
if
resp
.
StatusCode
!=
http
.
StatusOK
{
return
nil
,
nil
,
fmt
.
Errorf
(
"fetchAvailableModels 失败 (HTTP %d): %s"
,
resp
.
StatusCode
,
string
(
respBodyBytes
))
}
// 解析原始 JSON 为 map
var
rawResp
map
[
string
]
any
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
var
modelsResp
FetchAvailableModelsResponse
if
err
:=
json
.
Unmarshal
(
respBodyBytes
,
&
modelsResp
);
err
!=
nil
{
return
nil
,
nil
,
fmt
.
Errorf
(
"响应解析失败: %w"
,
err
)
}
// 解析原始 JSON 为 map
var
rawResp
map
[
string
]
any
_
=
json
.
Unmarshal
(
respBodyBytes
,
&
rawResp
)
return
&
modelsResp
,
rawResp
,
nil
}
return
&
modelsResp
,
rawResp
,
nil
return
nil
,
nil
,
lastErr
}
backend/internal/pkg/antigravity/oauth.go
View file @
11bfc807
...
...
@@ -32,17 +32,79 @@ const (
"https://www.googleapis.com/auth/cclog "
+
"https://www.googleapis.com/auth/experimentsandconfigs"
// API 端点
// 优先使用 sandbox daily URL,配额更宽松
BaseURL
=
"https://daily-cloudcode-pa.sandbox.googleapis.com"
// User-Agent(模拟官方客户端)
UserAgent
=
"antigravity/1.104.0 darwin/arm64"
// Session 过期时间
SessionTTL
=
30
*
time
.
Minute
// URL 可用性 TTL(不可用 URL 的恢复时间)
URLAvailabilityTTL
=
5
*
time
.
Minute
)
// BaseURLs 定义 Antigravity API 端点,按优先级排序
// fallback 顺序: sandbox → daily → prod
var
BaseURLs
=
[]
string
{
"https://daily-cloudcode-pa.sandbox.googleapis.com"
,
// sandbox
"https://daily-cloudcode-pa.googleapis.com"
,
// daily
"https://cloudcode-pa.googleapis.com"
,
// prod
}
// BaseURL 默认 URL(保持向后兼容)
var
BaseURL
=
BaseURLs
[
0
]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
type
URLAvailability
struct
{
mu
sync
.
RWMutex
unavailable
map
[
string
]
time
.
Time
// URL -> 恢复时间
ttl
time
.
Duration
}
// DefaultURLAvailability 全局 URL 可用性管理器
var
DefaultURLAvailability
=
NewURLAvailability
(
URLAvailabilityTTL
)
// NewURLAvailability 创建 URL 可用性管理器
func
NewURLAvailability
(
ttl
time
.
Duration
)
*
URLAvailability
{
return
&
URLAvailability
{
unavailable
:
make
(
map
[
string
]
time
.
Time
),
ttl
:
ttl
,
}
}
// MarkUnavailable 标记 URL 临时不可用
func
(
u
*
URLAvailability
)
MarkUnavailable
(
url
string
)
{
u
.
mu
.
Lock
()
defer
u
.
mu
.
Unlock
()
u
.
unavailable
[
url
]
=
time
.
Now
()
.
Add
(
u
.
ttl
)
}
// IsAvailable 检查 URL 是否可用
func
(
u
*
URLAvailability
)
IsAvailable
(
url
string
)
bool
{
u
.
mu
.
RLock
()
defer
u
.
mu
.
RUnlock
()
expiry
,
exists
:=
u
.
unavailable
[
url
]
if
!
exists
{
return
true
}
return
time
.
Now
()
.
After
(
expiry
)
}
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
func
(
u
*
URLAvailability
)
GetAvailableURLs
()
[]
string
{
u
.
mu
.
RLock
()
defer
u
.
mu
.
RUnlock
()
now
:=
time
.
Now
()
result
:=
make
([]
string
,
0
,
len
(
BaseURLs
))
for
_
,
url
:=
range
BaseURLs
{
expiry
,
exists
:=
u
.
unavailable
[
url
]
if
!
exists
||
now
.
After
(
expiry
)
{
result
=
append
(
result
,
url
)
}
}
return
result
}
// OAuthSession 保存 OAuth 授权流程的临时状态
type
OAuthSession
struct
{
State
string
`json:"state"`
...
...
backend/internal/pkg/ctxkey/ctxkey.go
View file @
11bfc807
...
...
@@ -7,4 +7,6 @@ type Key string
const
(
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform
Key
=
"ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient
Key
=
"ctx_is_claude_code_client"
)
backend/internal/pkg/geminicli/constants.go
View file @
11bfc807
...
...
@@ -27,10 +27,9 @@ const (
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes
=
"https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// DefaultScopes for Google One (personal Google accounts with Gemini access)
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client,
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
// DefaultGoogleOneScopes (DEPRECATED, no longer used)
// Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes.
// This constant is kept for backward compatibility but is not actively used.
DefaultGoogleOneScopes
=
"https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
...
...
backend/internal/pkg/geminicli/oauth.go
View file @
11bfc807
...
...
@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
effective
.
Scopes
=
DefaultAIStudioScopes
}
case
"google_one"
:
// Google One uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever
if
isBuiltinClient
{
effective
.
Scopes
=
DefaultCodeAssistScopes
}
else
{
effective
.
Scopes
=
DefaultGoogleOneScopes
}
// Google One always uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly
effective
.
Scopes
=
DefaultCodeAssistScopes
default
:
// Default to Code Assist scopes
effective
.
Scopes
=
DefaultCodeAssistScopes
...
...
backend/internal/pkg/geminicli/oauth_test.go
View file @
11bfc807
...
...
@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr
:
false
,
},
{
name
:
"Google One
with custom client
"
,
name
:
"Google One
always uses built-in client (even if custom credentials passed)
"
,
input
:
OAuthConfig
{
ClientID
:
"custom-client-id"
,
ClientSecret
:
"custom-client-secret"
,
},
oauthType
:
"google_one"
,
wantClientID
:
"custom-client-id"
,
wantScopes
:
Default
GoogleOneScopes
,
wantScopes
:
Default
CodeAssistScopes
,
// Uses code assist scopes even with custom client
wantErr
:
false
,
},
{
...
...
backend/internal/repository/account_repo.go
View file @
11bfc807
...
...
@@ -886,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args
=
append
(
args
,
*
updates
.
Status
)
idx
++
}
if
updates
.
Schedulable
!=
nil
{
setClauses
=
append
(
setClauses
,
"schedulable = $"
+
itoa
(
idx
))
args
=
append
(
args
,
*
updates
.
Schedulable
)
idx
++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if
len
(
updates
.
Credentials
)
>
0
{
payload
,
err
:=
json
.
Marshal
(
updates
.
Credentials
)
...
...
backend/internal/repository/api_key_repo.go
View file @
11bfc807
...
...
@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice2K
:
g
.
ImagePrice2k
,
ImagePrice4K
:
g
.
ImagePrice4k
,
DefaultValidityDays
:
g
.
DefaultValidityDays
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
}
...
...
backend/internal/repository/gateway_cache.go
View file @
11bfc807
...
...
@@ -2,6 +2,7 @@ package repository
import
(
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return
&
gatewayCache
{
rdb
:
rdb
}
}
func
(
c
*
gatewayCache
)
GetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
)
(
int64
,
error
)
{
key
:=
stickySessionPrefix
+
sessionHash
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func
buildSessionKey
(
groupID
int64
,
sessionHash
string
)
string
{
return
fmt
.
Sprintf
(
"%s%d:%s"
,
stickySessionPrefix
,
groupID
,
sessionHash
)
}
func
(
c
*
gatewayCache
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
key
:=
buildSessionKey
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int64
()
}
func
(
c
*
gatewayCache
)
SetSessionAccountID
(
ctx
context
.
Context
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
key
:=
sticky
Session
Prefix
+
sessionHash
func
(
c
*
gatewayCache
)
SetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
{
key
:=
build
Session
Key
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
accountID
,
ttl
)
.
Err
()
}
func
(
c
*
gatewayCache
)
RefreshSessionTTL
(
ctx
context
.
Context
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
key
:=
sticky
Session
Prefix
+
sessionHash
func
(
c
*
gatewayCache
)
RefreshSessionTTL
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
ttl
time
.
Duration
)
error
{
key
:=
build
Session
Key
(
groupID
,
sessionHash
)
return
c
.
rdb
.
Expire
(
ctx
,
key
,
ttl
)
.
Err
()
}
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment