"...internal/handler/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "62e80c602d0fdcd90aeb404a14ae30d66eaaadd1"
Commit 0b9c4ae6 authored by shaw's avatar shaw
Browse files

fix: 修复claude setup token授权效期短的问题

parent 0d5a8a95
...@@ -145,7 +145,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe ...@@ -145,7 +145,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
return fullCode, nil return fullCode, nil
} }
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) { func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state" // Parse code which may contain state in format "authCode#state"
...@@ -168,6 +168,11 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod ...@@ -168,6 +168,11 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["state"] = codeState reqBody["state"] = codeState
} }
// Setup token requires longer expiration (1 year)
if isSetupToken {
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
reqBodyJSON, _ := json.Marshal(reqBody) reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
......
...@@ -191,12 +191,13 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { ...@@ -191,12 +191,13 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
tests := []struct { tests := []struct {
name string name string
handler http.HandlerFunc handler http.HandlerFunc
code string code string
wantErr bool isSetupToken bool
wantResp *oauth.TokenResponse wantErr bool
validate func(captured requestCapture) wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{ }{
{ {
name: "sends_state_when_embedded", name: "sends_state_when_embedded",
...@@ -210,7 +211,8 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { ...@@ -210,7 +211,8 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
Scope: "s", Scope: "s",
}) })
}, },
code: "AUTH#STATE2", code: "AUTH#STATE2",
isSetupToken: false,
wantResp: &oauth.TokenResponse{ wantResp: &oauth.TokenResponse{
AccessToken: "at", AccessToken: "at",
RefreshToken: "rt", RefreshToken: "rt",
...@@ -223,6 +225,29 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { ...@@ -223,6 +225,29 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"]) require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"]) require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"]) require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
// Regular OAuth should not include expires_in
require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
},
},
{
name: "setup_token_includes_expires_in",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 31536000,
})
},
code: "AUTH",
isSetupToken: true,
wantResp: &oauth.TokenResponse{
AccessToken: "at",
},
validate: func(captured requestCapture) {
// Setup token should include expires_in with 1 year value
require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
"setup token should include expires_in: 31536000")
}, },
}, },
{ {
...@@ -231,8 +256,9 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { ...@@ -231,8 +256,9 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request")) _, _ = w.Write([]byte("bad request"))
}, },
code: "AUTH", code: "AUTH",
wantErr: true, isSetupToken: false,
wantErr: true,
}, },
} }
...@@ -254,7 +280,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { ...@@ -254,7 +280,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
s.client = client s.client = client
s.client.tokenURL = s.srv.URL s.client.tokenURL = s.srv.URL
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "") resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
if tt.wantErr { if tt.wantErr {
require.Error(s.T(), err) require.Error(s.T(), err)
......
...@@ -20,7 +20,7 @@ type OpenAIOAuthClient interface { ...@@ -20,7 +20,7 @@ type OpenAIOAuthClient interface {
type ClaudeOAuthClient interface { type ClaudeOAuthClient interface {
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
} }
...@@ -142,8 +142,11 @@ func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInpu ...@@ -142,8 +142,11 @@ func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInpu
} }
} }
// Determine if this is a setup token (scope is inference only)
isSetupToken := session.Scope == oauth.ScopeInference
// Exchange code for token // Exchange code for token
tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL) tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL, isSetupToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -172,10 +175,12 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( ...@@ -172,10 +175,12 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
} }
} }
// Determine scope // Determine scope and if this is a setup token
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
isSetupToken := false
if input.Scope == "inference" { if input.Scope == "inference" {
scope = oauth.ScopeInference scope = oauth.ScopeInference
isSetupToken = true
} }
// Step 1: Get organization info using sessionKey // Step 1: Get organization info using sessionKey
...@@ -203,7 +208,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( ...@@ -203,7 +208,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
} }
// Step 4: Exchange code for token // Step 4: Exchange code for token
tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL) tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL, isSetupToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err) return nil, fmt.Errorf("failed to exchange code: %w", err)
} }
...@@ -228,8 +233,8 @@ func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, org ...@@ -228,8 +233,8 @@ func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, org
} }
// exchangeCodeForToken exchanges authorization code for tokens // exchangeCodeForToken exchanges authorization code for tokens
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) { func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*TokenInfo, error) {
tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL) tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL, isSetupToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment