Commit 11bfc807 authored by song's avatar song
Browse files

merge upstream/main

parents c2a6ca8d 62dc0b95
...@@ -15,14 +15,16 @@ type AuthHandler struct { ...@@ -15,14 +15,16 @@ type AuthHandler struct {
cfg *config.Config cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService,
} }
} }
......
package handler
import (
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
linuxDoOAuthDefaultRedirectTo = "/dashboard"
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
)
type linuxDoTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
type linuxDoTokenExchangeError struct {
StatusCode int
ProviderError string
ProviderDescription string
Body string
}
func (e *linuxDoTokenExchangeError) Error() string {
if e == nil {
return ""
}
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
if strings.TrimSpace(e.ProviderError) != "" {
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
}
if strings.TrimSpace(e.ProviderDescription) != "" {
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
}
return strings.Join(parts, " ")
}
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
return
}
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = linuxDoOAuthDefaultFrontendCB
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
return
}
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
if err != nil {
description := ""
var exchangeErr *linuxDoTokenExchangeError
if errors.As(err, &exchangeErr) && exchangeErr != nil {
log.Printf(
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
exchangeErr.StatusCode,
exchangeErr.ProviderError,
exchangeErr.ProviderDescription,
truncateLogValue(exchangeErr.Body, 2048),
)
description = exchangeErr.Error()
} else {
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
description = err.Error()
}
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
return
}
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
if err != nil {
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", jwtToken)
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
}
if h == nil || h.cfg == nil {
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
if !h.cfg.LinuxDo.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
return h.cfg.LinuxDo, nil
}
func linuxDoExchangeCode(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
code string,
redirectURI string,
codeVerifier string,
) (*linuxDoTokenResponse, error) {
client := req.C().SetTimeout(30 * time.Second)
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json")
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
case "", "client_secret_post":
form.Set("client_secret", cfg.ClientSecret)
case "client_secret_basic":
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
case "none":
default:
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
}
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
if err != nil {
return nil, fmt.Errorf("request token: %w", err)
}
body := strings.TrimSpace(resp.String())
if !resp.IsSuccessState() {
providerErr, providerDesc := parseOAuthProviderError(body)
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
ProviderError: providerErr,
ProviderDescription: providerDesc,
Body: body,
}
}
tokenResp, ok := parseLinuxDoTokenResponse(body)
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
Body: body,
}
}
if strings.TrimSpace(tokenResp.TokenType) == "" {
tokenResp.TokenType = "Bearer"
}
return tokenResp, nil
}
func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
) (email string, username string, subject string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
getGJSON(body, "user.email"),
getGJSON(body, "data.email"),
getGJSON(body, "attributes.email"),
)
username = firstNonEmpty(
getGJSON(body, cfg.UserInfoUsernamePath),
getGJSON(body, "username"),
getGJSON(body, "preferred_username"),
getGJSON(body, "name"),
getGJSON(body, "user.username"),
getGJSON(body, "user.name"),
)
subject = firstNonEmpty(
getGJSON(body, cfg.UserInfoIDPath),
getGJSON(body, "sub"),
getGJSON(body, "id"),
getGJSON(body, "user_id"),
getGJSON(body, "uid"),
getGJSON(body, "user.id"),
)
subject = strings.TrimSpace(subject)
if subject == "" {
return "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
return "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
if email == "" {
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
email = linuxDoSyntheticEmail(subject)
}
username = strings.TrimSpace(username)
if username == "" {
username = "linuxdo_" + subject
}
return email, username, subject, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", redirectURI)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if cfg.UsePKCE {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
fragment := url.Values{}
fragment.Set("error", truncateFragmentValue(code))
if strings.TrimSpace(message) != "" {
fragment.Set("error_message", truncateFragmentValue(message))
}
if strings.TrimSpace(description) != "" {
fragment.Set("error_description", truncateFragmentValue(description))
}
redirectWithFragment(c, frontendCallback, fragment)
}
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
u, err := url.Parse(frontendCallback)
if err != nil {
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
u.Fragment = fragment.Encode()
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
c.Redirect(http.StatusFound, u.String())
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
v = strings.TrimSpace(v)
if v != "" {
return v
}
}
return ""
}
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
body = strings.TrimSpace(body)
if body == "" {
return "", ""
}
providerErr = firstNonEmpty(
getGJSON(body, "error"),
getGJSON(body, "code"),
getGJSON(body, "error.code"),
)
providerDesc = firstNonEmpty(
getGJSON(body, "error_description"),
getGJSON(body, "error.message"),
getGJSON(body, "message"),
getGJSON(body, "detail"),
)
if providerErr != "" || providerDesc != "" {
return providerErr, providerDesc
}
values, err := url.ParseQuery(body)
if err != nil {
return "", ""
}
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
return providerErr, providerDesc
}
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
body = strings.TrimSpace(body)
if body == "" {
return nil, false
}
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
if accessToken != "" {
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
scope := strings.TrimSpace(getGJSON(body, "scope"))
expiresIn := gjson.Get(body, "expires_in").Int()
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: tokenType,
ExpiresIn: expiresIn,
RefreshToken: refreshToken,
Scope: scope,
}, true
}
values, err := url.ParseQuery(body)
if err != nil {
return nil, false
}
accessToken = strings.TrimSpace(values.Get("access_token"))
if accessToken == "" {
return nil, false
}
expiresIn := int64(0)
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
expiresIn = v
}
}
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: strings.TrimSpace(values.Get("token_type")),
ExpiresIn: expiresIn,
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
Scope: strings.TrimSpace(values.Get("scope")),
}, true
}
func getGJSON(body string, path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
res := gjson.Get(body, path)
if !res.Exists() {
return ""
}
return res.String()
}
func truncateLogValue(value string, maxLen int) string {
value = strings.TrimSpace(value)
if value == "" || maxLen <= 0 {
return ""
}
if len(value) <= maxLen {
return value
}
value = value[:maxLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
return value
}
func singleLine(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
return strings.Join(strings.Fields(value), " ")
}
func sanitizeFrontendRedirectPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
if len(path) > linuxDoOAuthMaxRedirectLen {
return ""
}
// 只允许同源相对路径(避免开放重定向)。
if !strings.HasPrefix(path, "/") {
return ""
}
if strings.HasPrefix(path, "//") {
return ""
}
if strings.Contains(path, "://") {
return ""
}
if strings.ContainsAny(path, "\r\n") {
return ""
}
return path
}
func isRequestHTTPS(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
return proto == "https"
}
func encodeCookieValue(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func decodeCookieValue(value string) (string, error) {
raw, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return "", err
}
return string(raw), nil
}
func readCookieDecoded(c *gin.Context, name string) (string, error) {
ck, err := c.Request.Cookie(name)
if err != nil {
return "", err
}
return decodeCookieValue(ck.Value)
}
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: linuxDoOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: linuxDoOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
if len(value) > linuxDoOAuthMaxFragmentValueLen {
value = value[:linuxDoOAuthMaxFragmentValueLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
}
return value
}
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
tokenType = strings.TrimSpace(tokenType)
if tokenType == "" {
tokenType = "Bearer"
}
if !strings.EqualFold(tokenType, "Bearer") {
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
}
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", errors.New("missing access_token")
}
if strings.ContainsAny(accessToken, " \t\r\n") {
return "", errors.New("access_token contains whitespace")
}
return "Bearer " + accessToken, nil
}
func isSafeLinuxDoSubject(subject string) bool {
subject = strings.TrimSpace(subject)
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
return false
}
for _, r := range subject {
switch {
case r >= '0' && r <= '9':
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r == '_' || r == '-':
default:
return false
}
}
return true
}
func linuxDoSyntheticEmail(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return ""
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
package handler
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSanitizeFrontendRedirectPath(t *testing.T) {
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
}
func TestBuildBearerAuthorization(t *testing.T) {
auth, err := buildBearerAuthorization("", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
auth, err = buildBearerAuthorization("bearer", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
_, err = buildBearerAuthorization("MAC", "token123")
require.Error(t, err)
_, err = buildBearerAuthorization("Bearer", "token 123")
require.Error(t, err)
}
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
func TestParseOAuthProviderErrorJSON(t *testing.T) {
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
require.Equal(t, "invalid_client", code)
require.Equal(t, "bad secret", desc)
}
func TestParseOAuthProviderErrorForm(t *testing.T) {
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
require.Equal(t, "invalid_request", code)
require.Equal(t, "Missing code_verifier", desc)
}
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
require.True(t, ok)
require.Equal(t, "t1", token.AccessToken)
require.Equal(t, "Bearer", token.TokenType)
require.Equal(t, int64(3600), token.ExpiresIn)
require.Equal(t, "user", token.Scope)
}
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
require.True(t, ok)
require.Equal(t, "t2", token.AccessToken)
require.Equal(t, "bearer", token.TokenType)
require.Equal(t, int64(60), token.ExpiresIn)
}
func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
...@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group { ...@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount, AccountCount: g.AccountCount,
...@@ -280,6 +282,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag ...@@ -280,6 +282,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
FirstTokenMs: l.FirstTokenMs, FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount, ImageCount: l.ImageCount,
ImageSize: l.ImageSize, ImageSize: l.ImageSize,
UserAgent: l.UserAgent,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey), APIKey: APIKeyFromService(l.APIKey),
......
...@@ -17,6 +17,11 @@ type SystemSettings struct { ...@@ -17,6 +17,11 @@ type SystemSettings struct {
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
...@@ -50,5 +55,6 @@ type PublicSettings struct { ...@@ -50,5 +55,6 @@ type PublicSettings struct {
APIBaseURL string `json:"api_base_url"` APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"` DocURL string `json:"doc_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"` Version string `json:"version"`
} }
...@@ -52,6 +52,10 @@ type Group struct { ...@@ -52,6 +52,10 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
...@@ -180,6 +184,9 @@ type UsageLog struct { ...@@ -180,6 +184,9 @@ type UsageLog struct {
ImageCount int `json:"image_count"` ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"` ImageSize *string `json:"image_size"`
// User-Agent
UserAgent *string `json:"user_agent"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
......
...@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model reqModel := parsedReq.Model
reqStream := parsedReq.Stream reqStream := parsedReq.Stream
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext(c, body)
// 验证 model 必填 // 验证 model 必填
if reqModel == "" { if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
...@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
...@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
...@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { ...@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext(c, body)
// 验证 model 必填 // 验证 model 必填
if parsedReq.Model == "" { if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
......
...@@ -2,6 +2,7 @@ package handler ...@@ -2,6 +2,7 @@ package handler
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
"net/http" "net/http"
...@@ -13,6 +14,26 @@ import ( ...@@ -13,6 +14,26 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 解析请求体为 map
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
// 验证是否为 Claude Code 客户端
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
c.Request = c.Request.WithContext(ctx)
}
// 并发槽位等待相关常量 // 并发槽位等待相关常量
// //
// 性能优化说明: // 性能优化说明:
......
...@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body) parsedReq, _ := service.ParseGatewayRequest(body)
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext(c, body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash sessionKey := sessionHash
if sessionHash != "" { if sessionHash != "" {
...@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
......
...@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
......
...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL: settings.APIBaseURL, APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version, Version: h.version,
}) })
} }
...@@ -5,8 +5,11 @@ import ( ...@@ -5,8 +5,11 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
...@@ -22,10 +25,10 @@ func resolveHost(urlStr string) string { ...@@ -22,10 +25,10 @@ func resolveHost(urlStr string) string {
return parsed.Host return parsed.Host
} }
// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL,流式请求添加 ?alt=sse 参数 // 构建 URL,流式请求添加 ?alt=sse 参数
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action) apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action)
isStream := action == "streamGenerateContent" isStream := action == "streamGenerateContent"
if isStream { if isStream {
apiURL += "?alt=sse" apiURL += "?alt=sse"
...@@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) ...@@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte)
req.Host = host req.Host = host
} }
// 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header
return req, nil return req, nil
} }
// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点)
// 向后兼容:仅使用默认 BaseURL
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body)
}
// TokenResponse Google OAuth token 响应 // TokenResponse Google OAuth token 响应
type TokenResponse struct { type TokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
...@@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client { ...@@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client {
} }
} }
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func isConnectionError(err error) bool {
if err == nil {
return false
}
// 检查超时错误
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// 检查连接错误(DNS 失败、连接拒绝)
var opErr *net.OpError
if errors.As(err, &opErr) {
return true
}
// 检查 URL 错误
var urlErr *url.Error
return errors.As(err, &urlErr)
}
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
// 仅连接错误和 HTTP 429 触发 URL 降级
func shouldFallbackToNextURL(err error, statusCode int) bool {
if isConnectionError(err) {
return true
}
return statusCode == http.StatusTooManyRequests
}
// ExchangeCode 用 authorization code 交换 token // ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
params := url.Values{} params := url.Values{}
...@@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo ...@@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
} }
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON // LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
// 支持 URL fallback:sandbox → daily → prod
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
reqBody := LoadCodeAssistRequest{} reqBody := LoadCodeAssistRequest{}
reqBody.Metadata.IDEType = "ANTIGRAVITY" reqBody.Metadata.IDEType = "ANTIGRAVITY"
...@@ -281,10 +321,19 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -281,10 +321,19 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, fmt.Errorf("序列化请求失败: %w", err) return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
} }
url := BaseURL + "/v1internal:loadCodeAssist" // 获取可用的 URL 列表
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes))) availableURLs := DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:loadCodeAssist"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("创建请求失败: %w", err) lastErr = fmt.Errorf("创建请求失败: %w", err)
continue
} }
req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
...@@ -292,15 +341,28 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -292,15 +341,28 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err) lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, nil, lastErr
} }
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body) respBodyBytes, err := io.ReadAll(resp.Body)
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err) return nil, nil, fmt.Errorf("读取响应失败: %w", err)
} }
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
} }
...@@ -315,6 +377,9 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC ...@@ -315,6 +377,9 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
_ = json.Unmarshal(respBodyBytes, &rawResp) _ = json.Unmarshal(respBodyBytes, &rawResp)
return &loadResp, rawResp, nil return &loadResp, rawResp, nil
}
return nil, nil, lastErr
} }
// ModelQuotaInfo 模型配额信息 // ModelQuotaInfo 模型配额信息
...@@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct { ...@@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct {
} }
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
// 支持 URL fallback:sandbox → daily → prod
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
reqBody := FetchAvailableModelsRequest{Project: projectID} reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody) bodyBytes, err := json.Marshal(reqBody)
...@@ -346,10 +412,19 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -346,10 +412,19 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, fmt.Errorf("序列化请求失败: %w", err) return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
} }
apiURL := BaseURL + "/v1internal:fetchAvailableModels" // 获取可用的 URL 列表
availableURLs := DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
}
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:fetchAvailableModels"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("创建请求失败: %w", err) lastErr = fmt.Errorf("创建请求失败: %w", err)
continue
} }
req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
...@@ -357,15 +432,28 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -357,15 +432,28 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err) lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, nil, lastErr
} }
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body) respBodyBytes, err := io.ReadAll(resp.Body)
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err) return nil, nil, fmt.Errorf("读取响应失败: %w", err)
} }
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
} }
...@@ -380,4 +468,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI ...@@ -380,4 +468,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
_ = json.Unmarshal(respBodyBytes, &rawResp) _ = json.Unmarshal(respBodyBytes, &rawResp)
return &modelsResp, rawResp, nil return &modelsResp, rawResp, nil
}
return nil, nil, lastErr
} }
...@@ -32,17 +32,79 @@ const ( ...@@ -32,17 +32,79 @@ const (
"https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs" "https://www.googleapis.com/auth/experimentsandconfigs"
// API 端点
// 优先使用 sandbox daily URL,配额更宽松
BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
// User-Agent(模拟官方客户端) // User-Agent(模拟官方客户端)
UserAgent = "antigravity/1.104.0 darwin/arm64" UserAgent = "antigravity/1.104.0 darwin/arm64"
// Session 过期时间 // Session 过期时间
SessionTTL = 30 * time.Minute SessionTTL = 30 * time.Minute
// URL 可用性 TTL(不可用 URL 的恢复时间)
URLAvailabilityTTL = 5 * time.Minute
) )
// BaseURLs 定义 Antigravity API 端点,按优先级排序
// fallback 顺序: sandbox → daily → prod
var BaseURLs = []string{
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
"https://daily-cloudcode-pa.googleapis.com", // daily
"https://cloudcode-pa.googleapis.com", // prod
}
// BaseURL 默认 URL(保持向后兼容)
var BaseURL = BaseURLs[0]
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
type URLAvailability struct {
mu sync.RWMutex
unavailable map[string]time.Time // URL -> 恢复时间
ttl time.Duration
}
// DefaultURLAvailability 全局 URL 可用性管理器
var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL)
// NewURLAvailability 创建 URL 可用性管理器
func NewURLAvailability(ttl time.Duration) *URLAvailability {
return &URLAvailability{
unavailable: make(map[string]time.Time),
ttl: ttl,
}
}
// MarkUnavailable 标记 URL 临时不可用
func (u *URLAvailability) MarkUnavailable(url string) {
u.mu.Lock()
defer u.mu.Unlock()
u.unavailable[url] = time.Now().Add(u.ttl)
}
// IsAvailable 检查 URL 是否可用
func (u *URLAvailability) IsAvailable(url string) bool {
u.mu.RLock()
defer u.mu.RUnlock()
expiry, exists := u.unavailable[url]
if !exists {
return true
}
return time.Now().After(expiry)
}
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
func (u *URLAvailability) GetAvailableURLs() []string {
u.mu.RLock()
defer u.mu.RUnlock()
now := time.Now()
result := make([]string, 0, len(BaseURLs))
for _, url := range BaseURLs {
expiry, exists := u.unavailable[url]
if !exists || now.After(expiry) {
result = append(result, url)
}
}
return result
}
// OAuthSession 保存 OAuth 授权流程的临时状态 // OAuthSession 保存 OAuth 授权流程的临时状态
type OAuthSession struct { type OAuthSession struct {
State string `json:"state"` State string `json:"state"`
......
...@@ -7,4 +7,6 @@ type Key string ...@@ -7,4 +7,6 @@ type Key string
const ( const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform" ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
) )
...@@ -27,10 +27,9 @@ const ( ...@@ -27,10 +27,9 @@ const (
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform). // https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever" DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// DefaultScopes for Google One (personal Google accounts with Gemini access) // DefaultGoogleOneScopes (DEPRECATED, no longer used)
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client, // Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes.
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client // This constant is kept for backward compatibility but is not actively used.
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth. // GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
......
...@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error ...@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
effective.Scopes = DefaultAIStudioScopes effective.Scopes = DefaultAIStudioScopes
} }
case "google_one": case "google_one":
// Google One uses built-in Gemini CLI client (same as code_assist) // Google One always uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever // Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultGoogleOneScopes
}
default: default:
// Default to Code Assist scopes // Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes effective.Scopes = DefaultCodeAssistScopes
......
...@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Google One with custom client", name: "Google One always uses built-in client (even if custom credentials passed)",
input: OAuthConfig{ input: OAuthConfig{
ClientID: "custom-client-id", ClientID: "custom-client-id",
ClientSecret: "custom-client-secret", ClientSecret: "custom-client-secret",
}, },
oauthType: "google_one", oauthType: "google_one",
wantClientID: "custom-client-id", wantClientID: "custom-client-id",
wantScopes: DefaultGoogleOneScopes, wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
wantErr: false, wantErr: false,
}, },
{ {
......
...@@ -886,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -886,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Status) args = append(args, *updates.Status)
idx++ idx++
} }
if updates.Schedulable != nil {
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
args = append(args, *updates.Schedulable)
idx++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 { if len(updates.Credentials) > 0 {
payload, err := json.Marshal(updates.Credentials) payload, err := json.Marshal(updates.Credentials)
......
...@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice2K: g.ImagePrice2k, ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k, ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -2,6 +2,7 @@ package repository ...@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { ...@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb} return &gatewayCache{rdb: rdb}
} }
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { // buildSessionKey 构建 session key,包含 groupID 实现分组隔离
key := stickySessionPrefix + sessionHash // 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64() return c.rdb.Get(ctx, key).Int64()
} }
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err() return c.rdb.Set(ctx, key, accountID, ttl).Err()
} }
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err() return c.rdb.Expire(ctx, key, ttl).Err()
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment