Unverified Commit c25d491a authored by Marten Seemann's avatar Marten Seemann Committed by GitHub
Browse files

Merge pull request #19 from libp2p/one-packet-conn

use a single packet conn for all outgoing connections
parents b9d0283f 1751a3b6
package libp2pquic package libp2pquic
import ( import (
"net"
ic "github.com/libp2p/go-libp2p-crypto" ic "github.com/libp2p/go-libp2p-crypto"
peer "github.com/libp2p/go-libp2p-peer" peer "github.com/libp2p/go-libp2p-peer"
tpt "github.com/libp2p/go-libp2p-transport" tpt "github.com/libp2p/go-libp2p-transport"
smux "github.com/libp2p/go-stream-muxer" smux "github.com/libp2p/go-stream-muxer"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
) )
type conn struct { type conn struct {
...@@ -81,15 +78,3 @@ func (c *conn) RemoteMultiaddr() ma.Multiaddr { ...@@ -81,15 +78,3 @@ func (c *conn) RemoteMultiaddr() ma.Multiaddr {
func (c *conn) Transport() tpt.Transport { func (c *conn) Transport() tpt.Transport {
return c.transport return c.transport
} }
func quicMultiaddr(na net.Addr) (ma.Multiaddr, error) {
udpMA, err := manet.FromNetAddr(na)
if err != nil {
return nil, err
}
quicMA, err := ma.NewMultiaddr("/quic")
if err != nil {
return nil, err
}
return udpMA.Encapsulate(quicMA), nil
}
package libp2pquic package libp2pquic
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil" "io/ioutil"
"time"
ic "github.com/libp2p/go-libp2p-crypto" ic "github.com/libp2p/go-libp2p-crypto"
peer "github.com/libp2p/go-libp2p-peer" peer "github.com/libp2p/go-libp2p-peer"
...@@ -23,12 +25,14 @@ var _ = Describe("Connection", func() { ...@@ -23,12 +25,14 @@ var _ = Describe("Connection", func() {
serverID, clientID peer.ID serverID, clientID peer.ID
) )
createPeer := func() ic.PrivKey { createPeer := func() (peer.ID, ic.PrivKey) {
key, err := rsa.GenerateKey(rand.Reader, 1024) key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
priv, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(key)) priv, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(key))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return priv id, err := peer.IDFromPrivateKey(priv)
Expect(err).ToNot(HaveOccurred())
return id, priv
} }
runServer := func(tr tpt.Transport) (ma.Multiaddr, <-chan tpt.Conn) { runServer := func(tr tpt.Transport) (ma.Multiaddr, <-chan tpt.Conn) {
...@@ -54,13 +58,8 @@ var _ = Describe("Connection", func() { ...@@ -54,13 +58,8 @@ var _ = Describe("Connection", func() {
} }
BeforeEach(func() { BeforeEach(func() {
var err error serverID, serverKey = createPeer()
serverKey = createPeer() clientID, clientKey = createPeer()
serverID, err = peer.IDFromPrivateKey(serverKey)
Expect(err).ToNot(HaveOccurred())
clientKey = createPeer()
clientID, err = peer.IDFromPrivateKey(clientKey)
Expect(err).ToNot(HaveOccurred())
}) })
It("handshakes", func() { It("handshakes", func() {
...@@ -107,8 +106,7 @@ var _ = Describe("Connection", func() { ...@@ -107,8 +106,7 @@ var _ = Describe("Connection", func() {
}) })
It("fails if the peer ID doesn't match", func() { It("fails if the peer ID doesn't match", func() {
thirdPartyID, err := peer.IDFromPrivateKey(createPeer()) thirdPartyID, _ := createPeer()
Expect(err).ToNot(HaveOccurred())
serverTransport, err := NewTransport(serverKey) serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
...@@ -171,4 +169,57 @@ var _ = Describe("Connection", func() { ...@@ -171,4 +169,57 @@ var _ = Describe("Connection", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(serverConnChan).Should(Receive()) Eventually(serverConnChan).Should(Receive())
}) })
It("dials to two servers at the same time", func() {
serverID2, serverKey2 := createPeer()
serverTransport, err := NewTransport(serverKey)
Expect(err).ToNot(HaveOccurred())
serverAddr, serverConnChan := runServer(serverTransport)
serverTransport2, err := NewTransport(serverKey2)
Expect(err).ToNot(HaveOccurred())
serverAddr2, serverConnChan2 := runServer(serverTransport2)
data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB
// wait for both servers to accept a connection
// then send some data
go func() {
for _, c := range []tpt.Conn{<-serverConnChan, <-serverConnChan2} {
go func(conn tpt.Conn) {
defer GinkgoRecover()
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
}(c)
}
}()
clientTransport, err := NewTransport(clientKey)
Expect(err).ToNot(HaveOccurred())
c1, err := clientTransport.Dial(context.Background(), serverAddr, serverID)
Expect(err).ToNot(HaveOccurred())
c2, err := clientTransport.Dial(context.Background(), serverAddr2, serverID2)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}, 2)
// receive the data on both connections at the same time
for _, c := range []tpt.Conn{c1, c2} {
go func(conn tpt.Conn) {
defer GinkgoRecover()
str, err := conn.AcceptStream()
Expect(err).ToNot(HaveOccurred())
str.Close()
d, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal(data))
conn.Close()
done <- struct{}{}
}(c)
}
Eventually(done, 5*time.Second).Should(Receive())
Eventually(done, 5*time.Second).Should(Receive())
})
}) })
...@@ -35,7 +35,7 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID, ...@@ -35,7 +35,7 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID,
if err != nil { if err != nil {
return nil, err return nil, err
} }
localMultiaddr, err := quicMultiaddr(ln.Addr()) localMultiaddr, err := toQuicMultiaddr(ln.Addr())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -73,7 +73,7 @@ func (l *listener) setupConn(sess quic.Session) (tpt.Conn, error) { ...@@ -73,7 +73,7 @@ func (l *listener) setupConn(sess quic.Session) (tpt.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
remoteMultiaddr, err := quicMultiaddr(sess.RemoteAddr()) remoteMultiaddr, err := toQuicMultiaddr(sess.RemoteAddr())
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
package libp2pquic
import (
"net"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
)
var quicMA ma.Multiaddr
func init() {
var err error
quicMA, err = ma.NewMultiaddr("/quic")
if err != nil {
panic(err)
}
}
func toQuicMultiaddr(na net.Addr) (ma.Multiaddr, error) {
udpMA, err := manet.FromNetAddr(na)
if err != nil {
return nil, err
}
return udpMA.Encapsulate(quicMA), nil
}
func fromQuicMultiaddr(addr ma.Multiaddr) (net.Addr, error) {
return manet.ToNetAddr(addr.Decapsulate(quicMA))
}
package libp2pquic
import (
"net"
ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("QUIC Multiaddr", func() {
It("converts a net.Addr to a QUIC Multiaddr", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337}
maddr, err := toQuicMultiaddr(addr)
Expect(err).ToNot(HaveOccurred())
Expect(maddr.String()).To(Equal("/ip4/192.168.0.42/udp/1337/quic"))
})
It("converts a QUIC Multiaddr to a net.Addr", func() {
maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic")
Expect(err).ToNot(HaveOccurred())
addr, err := fromQuicMultiaddr(maddr)
Expect(err).ToNot(HaveOccurred())
Expect(addr).To(BeAssignableToTypeOf(&net.UDPAddr{}))
udpAddr := addr.(*net.UDPAddr)
Expect(udpAddr.IP).To(Equal(net.IPv4(192, 168, 0, 42)))
Expect(udpAddr.Port).To(Equal(1337))
})
})
...@@ -31,6 +31,7 @@ type transport struct { ...@@ -31,6 +31,7 @@ type transport struct {
privKey ic.PrivKey privKey ic.PrivKey
localPeer peer.ID localPeer peer.ID
tlsConf *tls.Config tlsConf *tls.Config
pconn net.PacketConn
} }
var _ tpt.Transport = &transport{} var _ tpt.Transport = &transport{}
...@@ -45,10 +46,22 @@ func NewTransport(key ic.PrivKey) (tpt.Transport, error) { ...@@ -45,10 +46,22 @@ func NewTransport(key ic.PrivKey) (tpt.Transport, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// create a packet conn for outgoing connections
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
return &transport{ return &transport{
privKey: key, privKey: key,
localPeer: localPeer, localPeer: localPeer,
tlsConf: tlsConf, tlsConf: tlsConf,
pconn: conn,
}, nil }, nil
} }
...@@ -58,6 +71,10 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp ...@@ -58,6 +71,10 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
if err != nil { if err != nil {
return nil, err return nil, err
} }
addr, err := fromQuicMultiaddr(raddr)
if err != nil {
return nil, err
}
var remotePubKey ic.PubKey var remotePubKey ic.PubKey
tlsConf := t.tlsConf.Clone() tlsConf := t.tlsConf.Clone()
// We need to check the peer ID in the VerifyPeerCertificate callback. // We need to check the peer ID in the VerifyPeerCertificate callback.
...@@ -82,11 +99,11 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp ...@@ -82,11 +99,11 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
} }
return nil return nil
} }
sess, err := quic.DialAddrContext(ctx, host, tlsConf, quicConfig) sess, err := quic.DialContext(ctx, t.pconn, addr, host, tlsConf, quicConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
localMultiaddr, err := quicMultiaddr(sess.LocalAddr()) localMultiaddr, err := toQuicMultiaddr(sess.LocalAddr())
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment