Unverified Commit d57690fe authored by Marten Seemann's avatar Marten Seemann Committed by GitHub
Browse files

Merge pull request #5 from marten-seemann/fix-listener-accept

close and discard a connection if the client's cert chain is invalid
parents fff1159c d558c774
......@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"io/ioutil"
......@@ -30,23 +31,26 @@ var _ = Describe("Connection", func() {
return priv
}
runServer := func() (<-chan ma.Multiaddr, <-chan tpt.Conn) {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
runServer := func(tr tpt.Transport) (ma.Multiaddr, <-chan tpt.Conn) {
addrChan := make(chan ma.Multiaddr)
connChan := make(chan tpt.Conn)
go func() {
defer GinkgoRecover()
addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred())
ln, err := serverTransport.Listen(addr)
ln, err := tr.Listen(addr)
Expect(err).ToNot(HaveOccurred())
addrChan <- ln.Multiaddr()
conn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
connChan <- conn
}()
return addrChan, connChan
return <-addrChan, connChan
}
// modify the cert chain such that verificiation will fail
invalidateCertChain := func(tlsConf *tls.Config) {
tlsConf.Certificates[0].Certificate = [][]byte{tlsConf.Certificates[0].Certificate[0]}
}
BeforeEach(func() {
......@@ -60,10 +64,12 @@ var _ = Describe("Connection", func() {
})
It("handshakes", func() {
serverAddrChan, serverConnChan := runServer()
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
serverAddr := <-serverAddrChan
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan
......@@ -78,10 +84,13 @@ var _ = Describe("Connection", func() {
})
It("opens and accepts streams", func() {
serverAddrChan, serverConnChan := runServer()
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), <-serverAddrChan, serverID)
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
serverConn := <-serverConnChan
......@@ -101,14 +110,65 @@ var _ = Describe("Connection", func() {
thirdPartyID, err := peer.IDFromPrivateKey(createPeer())
Expect(err).ToNot(HaveOccurred())
serverAddrChan, serverConnChan := runServer()
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
serverAddr := <-serverAddrChan
// dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), serverAddr, thirdPartyID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("TLS handshake error: bad certificate"))
Consistently(serverConnChan).ShouldNot(Receive())
})
It("fails if the client presents an invalid cert chain", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
invalidateCertChain(clientTransport.(*transport).tlsConf)
Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
Eventually(func() bool { return conn.IsClosed() }).Should(BeTrue())
Consistently(serverConnChan).ShouldNot(Receive())
})
It("fails if the server presents an invalid cert chain", func() {
serverTransport, err := NewTransport(serverKey)
invalidateCertChain(serverTransport.(*transport).tlsConf)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
_, err = clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("TLS handshake error: bad certificate"))
Consistently(serverConnChan).ShouldNot(Receive())
})
It("keeps accepting connections after a failed connection attempt", func() {
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
// first dial with an invalid cert chain
clientTransport1, err := NewTransport(clientKey)
invalidateCertChain(clientTransport1.(*transport).tlsConf)
Expect(err).ToNot(HaveOccurred())
_, err = clientTransport1.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
Consistently(serverConnChan).ShouldNot(Receive())
// then dial with a valid client
clientTransport2, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
_, err = clientTransport2.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
Eventually(serverConnChan).Should(Receive())
})
})
......@@ -49,12 +49,22 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID,
}
// Accept accepts new connections.
// TODO(#2): don't accept a connection if the client's peer verification fails
func (l *listener) Accept() (tpt.Conn, error) {
for {
sess, err := l.quicListener.Accept()
if err != nil {
return nil, err
}
conn, err := l.setupConn(sess)
if err != nil {
sess.Close(err)
continue
}
return conn, nil
}
}
func (l *listener) setupConn(sess quic.Session) (tpt.Conn, error) {
remotePubKey, err := getRemotePubKey(sess.ConnectionState().PeerCertificates)
if err != nil {
return nil, err
......
......@@ -57,6 +57,9 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
}
var remotePubKey ic.PubKey
tlsConf := t.tlsConf.Clone()
// We need to check the peer ID in the VerifyPeerCertificate callback.
// The tls.Config it is also used for listening, and we might also have concurrent dials.
// Clone it so we can check for the specific peer ID we're dialing here.
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
chain := make([]*x509.Certificate, len(rawCerts))
for i := 0; i < len(rawCerts); i++ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment