Unverified Commit d57690fe authored by Marten Seemann's avatar Marten Seemann Committed by GitHub
Browse files

Merge pull request #5 from marten-seemann/fix-listener-accept

close and discard a connection if the client's cert chain is invalid
parents fff1159c d558c774
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil" "io/ioutil"
...@@ -30,23 +31,26 @@ var _ = Describe("Connection", func() { ...@@ -30,23 +31,26 @@ var _ = Describe("Connection", func() {
return priv return priv
} }
runServer := func() (<-chan ma.Multiaddr, <-chan tpt.Conn) { runServer := func(tr tpt.Transport) (ma.Multiaddr, <-chan tpt.Conn) {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
addrChan := make(chan ma.Multiaddr) addrChan := make(chan ma.Multiaddr)
connChan := make(chan tpt.Conn) connChan := make(chan tpt.Conn)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ln, err := serverTransport.Listen(addr) ln, err := tr.Listen(addr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
addrChan <- ln.Multiaddr() addrChan <- ln.Multiaddr()
conn, err := ln.Accept() conn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
connChan <- conn connChan <- conn
}() }()
return addrChan, connChan return <-addrChan, connChan
}
// modify the cert chain such that verificiation will fail
invalidateCertChain := func(tlsConf *tls.Config) {
tlsConf.Certificates[0].Certificate = [][]byte{tlsConf.Certificates[0].Certificate[0]}
} }
BeforeEach(func() { BeforeEach(func() {
...@@ -60,10 +64,12 @@ var _ = Describe("Connection", func() { ...@@ -60,10 +64,12 @@ var _ = Describe("Connection", func() {
}) })
It("handshakes", func() { It("handshakes", func() {
serverAddrChan, serverConnChan := runServer() serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey) clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddr := <-serverAddrChan
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID) conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan serverConn := <-serverConnChan
...@@ -78,10 +84,13 @@ var _ = Describe("Connection", func() { ...@@ -78,10 +84,13 @@ var _ = Describe("Connection", func() {
}) })
It("opens and accepts streams", func() { It("opens and accepts streams", func() {
serverAddrChan, serverConnChan := runServer() serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey) clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), <-serverAddrChan, serverID) conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan serverConn := <-serverConnChan
...@@ -101,14 +110,65 @@ var _ = Describe("Connection", func() { ...@@ -101,14 +110,65 @@ var _ = Describe("Connection", func() {
thirdPartyID, err := peer.IDFromPrivateKey(createPeer()) thirdPartyID, err := peer.IDFromPrivateKey(createPeer())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddrChan, serverConnChan := runServer() serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey) clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddr := <-serverAddrChan
// dial, but expect the wrong peer ID // dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), serverAddr, thirdPartyID) _, err = clientTransport.Dial(context.Background(), serverAddr, thirdPartyID)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("TLS handshake error: bad certificate")) Expect(err.Error()).To(ContainSubstring("TLS handshake error: bad certificate"))
Consistently(serverConnChan).ShouldNot(Receive()) Consistently(serverConnChan).ShouldNot(Receive())
}) })
It("fails if the client presents an invalid cert chain", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
invalidateCertChain(clientTransport.(*transport).tlsConf)
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
Eventually(func() bool { return conn.IsClosed() }).Should(BeTrue())
Consistently(serverConnChan).ShouldNot(Receive())
})
It("fails if the server presents an invalid cert chain", func() {
serverTransport, err := NewTransport(serverKey)
invalidateCertChain(serverTransport.(*transport).tlsConf)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
_, err = clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("TLS handshake error: bad certificate"))
Consistently(serverConnChan).ShouldNot(Receive())
})
It("keeps accepting connections after a failed connection attempt", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
// first dial with an invalid cert chain
clientTransport1, err := NewTransport(clientKey)
invalidateCertChain(clientTransport1.(*transport).tlsConf)
Expect(err).ToNot(HaveOccurred())
_, err = clientTransport1.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
Consistently(serverConnChan).ShouldNot(Receive())
// then dial with a valid client
clientTransport2, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
_, err = clientTransport2.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
Eventually(serverConnChan).Should(Receive())
})
}) })
...@@ -49,12 +49,22 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID, ...@@ -49,12 +49,22 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID,
} }
// Accept accepts new connections. // Accept accepts new connections.
// TODO(#2): don't accept a connection if the client's peer verification fails
func (l *listener) Accept() (tpt.Conn, error) { func (l *listener) Accept() (tpt.Conn, error) {
for {
sess, err := l.quicListener.Accept() sess, err := l.quicListener.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := l.setupConn(sess)
if err != nil {
sess.Close(err)
continue
}
return conn, nil
}
}
func (l *listener) setupConn(sess quic.Session) (tpt.Conn, error) {
remotePubKey, err := getRemotePubKey(sess.ConnectionState().PeerCertificates) remotePubKey, err := getRemotePubKey(sess.ConnectionState().PeerCertificates)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -57,6 +57,9 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp ...@@ -57,6 +57,9 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
} }
var remotePubKey ic.PubKey var remotePubKey ic.PubKey
tlsConf := t.tlsConf.Clone() tlsConf := t.tlsConf.Clone()
// We need to check the peer ID in the VerifyPeerCertificate callback.
// The tls.Config it is also used for listening, and we might also have concurrent dials.
// Clone it so we can check for the specific peer ID we're dialing here.
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
chain := make([]*x509.Certificate, len(rawCerts)) chain := make([]*x509.Certificate, len(rawCerts))
for i := 0; i < len(rawCerts); i++ { for i := 0; i < len(rawCerts); i++ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment