Commit 5013febb authored by Jeromy Johnson's avatar Jeromy Johnson Committed by GitHub
Browse files

Merge pull request #89 from libp2p/feat/protocol-prefs

Feat/protocol prefs
parents d8468400 0b90707e
......@@ -5,7 +5,7 @@ os:
language: go
go:
- 1.5.2
- 1.7
env:
- GO15VENDOREXPERIMENT=1
......
......@@ -97,7 +97,7 @@ func main() {
log.Println("opening stream...")
// make a new stream from host B to host A
// it should be handled on host A by the handler we set
s, err := ha.NewStream(context.Background(), "/hello/1.0.0", a.ID())
s, err := ha.NewStream(context.Background(), a.ID(), "/hello/1.0.0")
if err != nil {
log.Fatalln(err)
}
......
......@@ -111,10 +111,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) {
}
return
}
s.SetProtocol(protocol.ID(protoID))
logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc)
logStream := mstream.WrapStream(s, h.bwc)
s.SetProtocol(protoID)
go handle(protoID, logStream)
}
......@@ -150,7 +150,7 @@ func (h *BasicHost) IDService() *identify.IDService {
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) {
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream)
is.SetProtocol(p)
is.SetProtocol(protocol.ID(p))
handler(is)
return nil
})
......@@ -161,7 +161,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream)
is.SetProtocol(p)
is.SetProtocol(protocol.ID(p))
handler(is)
return nil
})
......@@ -176,13 +176,69 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
// header with given protocol.ID. If there is no connection to p, attempts
// to create one. If ProtocolID is "", writes no header.
// (Threadsafe)
func (h *BasicHost) NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) {
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) {
pref, err := h.preferredProtocol(p, pids)
if err != nil {
return nil, err
}
if pref != "" {
return h.newStream(ctx, p, pref)
}
var protoStrs []string
for _, pid := range pids {
protoStrs = append(protoStrs, string(pid))
}
s, err := h.Network().NewStream(ctx, p)
if err != nil {
return nil, err
}
logStream := mstream.WrapStream(s, pid, h.bwc)
selected, err := msmux.SelectOneOf(protoStrs, s)
if err != nil {
s.Close()
return nil, err
}
selpid := protocol.ID(selected)
s.SetProtocol(selpid)
h.Peerstore().AddProtocols(p, selected)
return mstream.WrapStream(s, h.bwc), nil
}
func pidsToStrings(pids []protocol.ID) []string {
out := make([]string, len(pids))
for i, p := range pids {
out[i] = string(p)
}
return out
}
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) (protocol.ID, error) {
pidstrs := pidsToStrings(pids)
supported, err := h.Peerstore().SupportsProtocols(p, pidstrs...)
if err != nil {
return "", err
}
var out protocol.ID
if len(supported) > 0 {
out = protocol.ID(supported[0])
}
return out, nil
}
func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (inet.Stream, error) {
s, err := h.Network().NewStream(ctx, p)
if err != nil {
return nil, err
}
s.SetProtocol(pid)
logStream := mstream.WrapStream(s, h.bwc)
lzcon := msmux.NewMSSelect(logStream, string(pid))
return &streamWrapper{
......
......@@ -4,7 +4,9 @@ import (
"bytes"
"io"
"testing"
"time"
host "github.com/libp2p/go-libp2p/p2p/host"
inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
testutil "github.com/libp2p/go-libp2p/p2p/test/util"
......@@ -32,7 +34,7 @@ func TestHostSimple(t *testing.T) {
io.Copy(w, s) // mirror everything
})
s, err := h1.NewStream(ctx, protocol.TestingID, h2pi.ID)
s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID)
if err != nil {
t.Fatal(err)
}
......@@ -61,3 +63,182 @@ func TestHostSimple(t *testing.T) {
t.Fatal("buf1 != buf3 -- %x != %x", buf1, buf3)
}
}
func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) {
h1 := testutil.GenHostSwarm(t, ctx)
h2 := testutil.GenHostSwarm(t, ctx)
h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil {
t.Fatal(err)
}
return h1, h2
}
func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) {
select {
case proto := <-c:
if proto != exp {
t.Fatal("should have connected on ", exp)
}
case <-time.After(time.Second * 5):
t.Fatal("timeout waiting for stream")
}
}
func TestHostProtoPreference(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h1, h2 := getHostPair(ctx, t)
defer h1.Close()
defer h2.Close()
protoOld := protocol.ID("/testing")
protoNew := protocol.ID("/testing/1.1.0")
protoMinor := protocol.ID("/testing/1.2.0")
connectedOn := make(chan protocol.ID, 16)
handler := func(s inet.Stream) {
connectedOn <- s.Protocol()
s.Close()
}
h1.SetStreamHandler(protoOld, handler)
s, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld)
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, protoOld)
s.Close()
mfunc, err := host.MultistreamSemverMatcher(protoMinor)
if err != nil {
t.Fatal(err)
}
h1.SetStreamHandlerMatch(protoMinor, mfunc, handler)
// remembered preference will be chosen first, even when the other side newly supports it
s2, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld)
if err != nil {
t.Fatal(err)
}
// required to force 'lazy' handshake
_, err = s2.Write([]byte("hello"))
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, protoOld)
s2.Close()
s3, err := h2.NewStream(ctx, h1.ID(), protoMinor)
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, protoMinor)
s3.Close()
}
func TestHostProtoMismatch(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h1, h2 := getHostPair(ctx, t)
defer h1.Close()
defer h2.Close()
h1.SetStreamHandler("/super", func(s inet.Stream) {
t.Error("shouldnt get here")
s.Close()
})
_, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/baz/1.0.0")
if err == nil {
t.Fatal("expected new stream to fail")
}
}
func TestHostProtoPreknowledge(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h1 := testutil.GenHostSwarm(t, ctx)
h2 := testutil.GenHostSwarm(t, ctx)
conn := make(chan protocol.ID, 16)
handler := func(s inet.Stream) {
conn <- s.Protocol()
s.Close()
}
h1.SetStreamHandler("/super", handler)
h2pi := h2.Peerstore().PeerInfo(h2.ID())
if err := h1.Connect(ctx, h2pi); err != nil {
t.Fatal(err)
}
defer h1.Close()
defer h2.Close()
// wait for identify handshake to finish completely
time.Sleep(time.Millisecond * 20)
h1.SetStreamHandler("/foo", handler)
s, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/super")
if err != nil {
t.Fatal(err)
}
select {
case p := <-conn:
t.Fatal("shouldnt have gotten connection yet, we should have a lazy stream: ", p)
case <-time.After(time.Millisecond * 50):
}
_, err = s.Read(nil)
if err != nil {
t.Fatal(err)
}
assertWait(t, conn, "/super")
s.Close()
}
func TestNewDialOld(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h1, h2 := getHostPair(ctx, t)
defer h1.Close()
defer h2.Close()
connectedOn := make(chan protocol.ID, 16)
h1.SetStreamHandler("/testing", func(s inet.Stream) {
connectedOn <- s.Protocol()
s.Close()
})
s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing")
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, "/testing")
if s.Protocol() != "/testing" {
t.Fatal("shoould have gotten /testing")
}
s.Close()
}
......@@ -60,7 +60,7 @@ type Host interface {
// header with given protocol.ID. If there is no connection to p, attempts
// to create one. If ProtocolID is "", writes no header.
// (Threadsafe)
NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error)
NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error)
// Close shuts down the host, its Network, and services.
Close() error
......
package host
import (
"github.com/libp2p/go-libp2p/p2p/protocol"
"strings"
semver "github.com/coreos/go-semver/semver"
)
func MultistreamSemverMatcher(base string) (func(string) bool, error) {
parts := strings.Split(base, "/")
func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) {
parts := strings.Split(string(base), "/")
vers, err := semver.NewVersion(parts[len(parts)-1])
if err != nil {
return nil, err
......
......@@ -118,8 +118,8 @@ func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) {
rh.host.RemoveStreamHandler(pid)
}
func (rh *RoutedHost) NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) {
return rh.host.NewStream(ctx, pid, p)
func (rh *RoutedHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) {
return rh.host.NewStream(ctx, p, pids...)
}
func (rh *RoutedHost) Close() error {
// no need to close IpfsRouting. we dont own it.
......
......@@ -19,18 +19,18 @@ type meteredStream struct {
mesRecv metrics.StreamMeterCallback
}
func newMeteredStream(base inet.Stream, pid protocol.ID, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
func newMeteredStream(base inet.Stream, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
return &meteredStream{
Stream: base,
mesSent: sentCB,
mesRecv: recvCB,
protoKey: pid,
protoKey: base.Protocol(),
peerKey: p,
}
}
func WrapStream(base inet.Stream, pid protocol.ID, bwc metrics.Reporter) inet.Stream {
return newMeteredStream(base, pid, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
func WrapStream(base inet.Stream, bwc metrics.Reporter) inet.Stream {
return newMeteredStream(base, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
}
func (s *meteredStream) Read(b []byte) (int, error) {
......
......@@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) {
return len(b), nil
}
func (fs *FakeStream) Protocol() protocol.ID {
return "TEST"
}
func TestCallbacksWork(t *testing.T) {
fake := new(FakeStream)
......@@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) {
recv += n
}
ms := newMeteredStream(fake, protocol.ID("TEST"), peer.ID("PEER"), recvCB, sentCB)
ms := newMeteredStream(fake, peer.ID("PEER"), recvCB, sentCB)
toWrite := int64(100000)
toRead := int64(100000)
......
......@@ -8,6 +8,7 @@ import (
ma "github.com/jbenet/go-multiaddr"
"github.com/jbenet/goprocess"
conn "github.com/libp2p/go-libp2p/p2p/net/conn"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
context "golang.org/x/net/context"
)
......@@ -26,8 +27,8 @@ type Stream interface {
io.Writer
io.Closer
Protocol() string
SetProtocol(string)
Protocol() protocol.ID
SetProtocol(protocol.ID)
// Conn returns the connection this stream is part of.
Conn() Conn
......
......@@ -7,6 +7,7 @@ import (
process "github.com/jbenet/goprocess"
inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
)
// stream implements inet.Stream
......@@ -17,7 +18,7 @@ type stream struct {
toDeliver chan *transportObject
proc process.Process
protocol string
protocol protocol.ID
}
type transportObject struct {
......@@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (s *stream) Protocol() string {
func (s *stream) Protocol() protocol.ID {
return s.protocol
}
func (s *stream) SetProtocol(proto string) {
func (s *stream) SetProtocol(proto protocol.ID) {
s.protocol = proto
}
......
......@@ -298,7 +298,7 @@ func TestStreams(t *testing.T) {
h.SetStreamHandler(protocol.TestingID, handler)
}
s, err := hosts[0].NewStream(ctx, protocol.TestingID, hosts[1].ID())
s, err := hosts[0].NewStream(ctx, hosts[1].ID(), protocol.TestingID)
if err != nil {
t.Fatal(err)
}
......@@ -386,7 +386,7 @@ func TestStreamsStress(t *testing.T) {
defer wg.Done()
from := rand.Intn(len(hosts))
to := rand.Intn(len(hosts))
s, err := hosts[from].NewStream(ctx, protocol.TestingID, hosts[to].ID())
s, err := hosts[from].NewStream(ctx, hosts[to].ID(), protocol.TestingID)
if err != nil {
log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to])
panic(err)
......@@ -466,7 +466,7 @@ func TestAdding(t *testing.T) {
}
ctx := context.Background()
s, err := h1.NewStream(ctx, protocol.TestingID, p2)
s, err := h1.NewStream(ctx, p2, protocol.TestingID)
if err != nil {
t.Fatal(err)
}
......@@ -563,7 +563,7 @@ func TestLimitedStreams(t *testing.T) {
}
ctx := context.Background()
s, err := hosts[0].NewStream(ctx, protocol.TestingID, hosts[1].ID())
s, err := hosts[0].NewStream(ctx, hosts[1].ID(), protocol.TestingID)
if err != nil {
t.Fatal(err)
}
......
......@@ -68,6 +68,8 @@ func TestConnectednessCorrect(t *testing.T) {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
expectConnectedness(t, nets[2], nets[1], inet.NotConnected)
for _, n := range nets {
......
......@@ -2,6 +2,7 @@ package swarm
import (
inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
ps "github.com/jbenet/go-peerstream"
)
......@@ -10,7 +11,7 @@ import (
// our Conn and Swarm (instead of just the ps.Conn and ps.Swarm)
type Stream struct {
stream *ps.Stream
protocol string
protocol protocol.ID
}
// Stream returns the underlying peerstream.Stream
......@@ -44,11 +45,11 @@ func (s *Stream) Close() error {
return s.stream.Close()
}
func (s *Stream) Protocol() string {
func (s *Stream) Protocol() protocol.ID {
return s.protocol
}
func (s *Stream) SetProtocol(p string) {
func (s *Stream) SetProtocol(p protocol.ID) {
s.protocol = p
}
......
......@@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) {
return
}
s.SetProtocol(ID)
bwc := ids.Host.GetBandwidthReporter()
s = mstream.WrapStream(s, ID, bwc)
s = mstream.WrapStream(s, bwc)
// ok give the response to our handler.
if err := msmux.SelectProtoOrFail(ID, s); err != nil {
......@@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) {
c := s.Conn()
bwc := ids.Host.GetBandwidthReporter()
s = mstream.WrapStream(s, ID, bwc)
s = mstream.WrapStream(s, bwc)
w := ggio.NewDelimitedWriter(s)
mes := pb.Identify{}
......@@ -173,7 +175,7 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c inet.Conn) {
p := c.RemotePeer()
// mes.Protocols
ids.Host.Peerstore().SetProtocols(p, mes.Protocols)
ids.Host.Peerstore().AddProtocols(p, mes.Protocols...)
// mes.ObservedAddr
ids.consumeObservedAddress(mes.GetObservedAddr(), c)
......
......@@ -68,7 +68,7 @@ func (p *PingService) PingHandler(s inet.Stream) {
}
func (ps *PingService) Ping(ctx context.Context, p peer.ID) (<-chan time.Duration, error) {
s, err := ps.Host.NewStream(ctx, ID, p)
s, err := ps.Host.NewStream(ctx, p, ID)
if err != nil {
return nil, err
}
......
......@@ -123,7 +123,7 @@ func (rs *RelayService) pipeStream(src, dst peer.ID, s inet.Stream) error {
// for now, can only open streams to directly connected peers.
// maybe we can do some routing later on.
func (rs *RelayService) openStreamToPeer(ctx context.Context, p peer.ID) (inet.Stream, error) {
return rs.host.NewStream(ctx, ID, p)
return rs.host.NewStream(ctx, p, ID)
}
func ReadHeader(r io.Reader) (src, dst peer.ID, err error) {
......
......@@ -49,7 +49,7 @@ func TestRelaySimple(t *testing.T) {
// ok, now we can try to relay n1--->n2--->n3.
log.Debug("open relay stream")
s, err := n1.NewStream(ctx, relay.ID, n2p)
s, err := n1.NewStream(ctx, n2p, relay.ID)
if err != nil {
t.Fatal(err)
}
......@@ -144,7 +144,7 @@ func TestRelayAcrossFour(t *testing.T) {
// ok, now we can try to relay n1--->n2--->n3--->n4--->n5
log.Debug("open relay stream")
s, err := n1.NewStream(ctx, relay.ID, n2p)
s, err := n1.NewStream(ctx, n2p, relay.ID)
if err != nil {
t.Fatal(err)
}
......@@ -244,7 +244,7 @@ func TestRelayStress(t *testing.T) {
// ok, now we can try to relay n1--->n2--->n3.
log.Debug("open relay stream")
s, err := n1.NewStream(ctx, relay.ID, n2p)
s, err := n1.NewStream(ctx, n2p, relay.ID)
if err != nil {
t.Fatal(err)
}
......
......@@ -83,7 +83,7 @@ a problem.
}()
for {
s, err = host.NewStream(context.Background(), protocol.TestingID, remote)
s, err = host.NewStream(context.Background(), remote, protocol.TestingID)
if err != nil {
return
}
......@@ -285,7 +285,7 @@ func TestStBackpressureStreamWrite(t *testing.T) {
}
// open a stream, from 2->1, this is our reader
s, err := h2.NewStream(context.Background(), protocol.TestingID, h1.ID())
s, err := h2.NewStream(context.Background(), h1.ID(), protocol.TestingID)
if err != nil {
t.Fatal(err)
}
......
......@@ -177,7 +177,7 @@ func SubtestConnSendDisc(t *testing.T, hosts []host.Host) {
for i := 0; i < numStreams; i++ {
h1 := hosts[i%len(hosts)]
h2 := hosts[(i+1)%len(hosts)]
s, err := h1.NewStream(context.Background(), protocol.TestingID, h2.ID())
s, err := h1.NewStream(context.Background(), h2.ID(), protocol.TestingID)
if err != nil {
t.Error(err)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment