diff --git a/examples/hosts/main.go b/examples/hosts/main.go index 256fdbff8f8d6336caffe3d074aee16e90725977..2fe0b8fe5bba3cb73ef0cbf6a523da5a1abd5ac7 100644 --- a/examples/hosts/main.go +++ b/examples/hosts/main.go @@ -97,7 +97,7 @@ func main() { log.Println("opening stream...") // make a new stream from host B to host A // it should be handled on host A by the handler we set - s, err := ha.NewStream(context.Background(), "/hello/1.0.0", a.ID()) + s, err := ha.NewStream(context.Background(), a.ID(), "/hello/1.0.0") if err != nil { log.Fatalln(err) } diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 68b3e08e013d96e69e082caaf4d2300c77548774..186094c630a4876e1df834a46575d152e87ee163 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -2,6 +2,7 @@ package basichost import ( "io" + "sync" peer "github.com/ipfs/go-libp2p-peer" pstore "github.com/ipfs/go-libp2p-peerstore" @@ -46,6 +47,9 @@ type BasicHost struct { relay *relay.RelayService natmgr *natManager + protoPrefs map[peer.ID]map[protocol.ID]struct{} + prefsLk sync.Mutex + proc goprocess.Process bwc metrics.Reporter @@ -54,9 +58,10 @@ type BasicHost struct { // New constructs and sets up a new *BasicHost with given Network func New(net inet.Network, opts ...interface{}) *BasicHost { h := &BasicHost{ - network: net, - mux: msmux.NewMultistreamMuxer(), - bwc: metrics.NewBandwidthCounter(), + network: net, + mux: msmux.NewMultistreamMuxer(), + bwc: metrics.NewBandwidthCounter(), + protoPrefs: make(map[peer.ID]map[protocol.ID]struct{}), } h.proc = goprocess.WithTeardown(func() error { @@ -176,7 +181,58 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { // header with given protocol.ID. If there is no connection to p, attempts // to create one. If ProtocolID is "", writes no header. // (Threadsafe) -func (h *BasicHost) NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) { +func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) { + pref := h.preferredProtocol(p, pids) + if pref != "" { + return h.newStream(ctx, p, pref) + } + + var lastErr error + for _, pid := range pids { + s, err := h.newStream(ctx, p, pid) + if err == nil { + h.setPreferredProtocol(p, pid) + return s, nil + } + lastErr = err + log.Infof("NewStream to %s for %s failed: %s", p, pid, err) + } + + return nil, lastErr +} + +func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID { + h.prefsLk.Lock() + defer h.prefsLk.Unlock() + + prefs, ok := h.protoPrefs[p] + if !ok { + return "" + } + + for _, pid := range pids { + if _, ok := prefs[pid]; ok { + return pid + } + } + + return "" +} + +func (h *BasicHost) setPreferredProtocol(p peer.ID, proto protocol.ID) { + h.prefsLk.Lock() + defer h.prefsLk.Unlock() + + prefs, ok := h.protoPrefs[p] + if !ok { + prefs = make(map[protocol.ID]struct{}) + h.protoPrefs[p] = prefs + } + + prefs[proto] = struct{}{} +} + +func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (inet.Stream, error) { s, err := h.Network().NewStream(ctx, p) if err != nil { return nil, err diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index f1f2e248a3fa4629e7369461c3777f94614d39d6..a7c893ef88ceb31a6260c5dcf4f5f87d836c20bc 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -32,7 +32,7 @@ func TestHostSimple(t *testing.T) { io.Copy(w, s) // mirror everything }) - s, err := h1.NewStream(ctx, protocol.TestingID, h2pi.ID) + s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID) if err != nil { t.Fatal(err) } diff --git a/p2p/host/host.go b/p2p/host/host.go index 65810e0334a86557e5b361a0de02433e8ead0f3b..c4da4f51b376047dec4ea00fcb38d123185bc69e 100644 --- a/p2p/host/host.go +++ b/p2p/host/host.go @@ -60,7 +60,7 @@ type Host interface { // header with given protocol.ID. If there is no connection to p, attempts // to create one. If ProtocolID is "", writes no header. // (Threadsafe) - NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) + NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) // Close shuts down the host, its Network, and services. Close() error diff --git a/p2p/host/routed/routed.go b/p2p/host/routed/routed.go index 4d25afdd4d69d21cb65d4010cd54cdc4ebbf7b0a..928297df7f8399b5ab842274c9e676182723179f 100644 --- a/p2p/host/routed/routed.go +++ b/p2p/host/routed/routed.go @@ -118,8 +118,8 @@ func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) { rh.host.RemoveStreamHandler(pid) } -func (rh *RoutedHost) NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) { - return rh.host.NewStream(ctx, pid, p) +func (rh *RoutedHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) { + return rh.host.NewStream(ctx, p, pids...) } func (rh *RoutedHost) Close() error { // no need to close IpfsRouting. we dont own it. diff --git a/p2p/net/mock/mock_test.go b/p2p/net/mock/mock_test.go index 9c33b3bb6a1c3045e7321bb3e4c57468866cb14d..973104e255f2d8f83ac2052c64c7df55fa1cc3b6 100644 --- a/p2p/net/mock/mock_test.go +++ b/p2p/net/mock/mock_test.go @@ -298,7 +298,7 @@ func TestStreams(t *testing.T) { h.SetStreamHandler(protocol.TestingID, handler) } - s, err := hosts[0].NewStream(ctx, protocol.TestingID, hosts[1].ID()) + s, err := hosts[0].NewStream(ctx, hosts[1].ID(), protocol.TestingID) if err != nil { t.Fatal(err) } @@ -386,7 +386,7 @@ func TestStreamsStress(t *testing.T) { defer wg.Done() from := rand.Intn(len(hosts)) to := rand.Intn(len(hosts)) - s, err := hosts[from].NewStream(ctx, protocol.TestingID, hosts[to].ID()) + s, err := hosts[from].NewStream(ctx, hosts[to].ID(), protocol.TestingID) if err != nil { log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to]) panic(err) @@ -466,7 +466,7 @@ func TestAdding(t *testing.T) { } ctx := context.Background() - s, err := h1.NewStream(ctx, protocol.TestingID, p2) + s, err := h1.NewStream(ctx, p2, protocol.TestingID) if err != nil { t.Fatal(err) } @@ -563,7 +563,7 @@ func TestLimitedStreams(t *testing.T) { } ctx := context.Background() - s, err := hosts[0].NewStream(ctx, protocol.TestingID, hosts[1].ID()) + s, err := hosts[0].NewStream(ctx, hosts[1].ID(), protocol.TestingID) if err != nil { t.Fatal(err) } diff --git a/p2p/protocol/ping/ping.go b/p2p/protocol/ping/ping.go index e16ff6686e7de78c5403d2873fec359673b9b6e5..e4dc40a967905342f48882af609907cb6d9527b7 100644 --- a/p2p/protocol/ping/ping.go +++ b/p2p/protocol/ping/ping.go @@ -68,7 +68,7 @@ func (p *PingService) PingHandler(s inet.Stream) { } func (ps *PingService) Ping(ctx context.Context, p peer.ID) (<-chan time.Duration, error) { - s, err := ps.Host.NewStream(ctx, ID, p) + s, err := ps.Host.NewStream(ctx, p, ID) if err != nil { return nil, err } diff --git a/p2p/protocol/relay/relay.go b/p2p/protocol/relay/relay.go index 6fa3c52a552b3b1cd734c3de46c9d8269c1a3d75..885018dee6fefb1a213f95ed855ee8a165a0d40f 100644 --- a/p2p/protocol/relay/relay.go +++ b/p2p/protocol/relay/relay.go @@ -123,7 +123,7 @@ func (rs *RelayService) pipeStream(src, dst peer.ID, s inet.Stream) error { // for now, can only open streams to directly connected peers. // maybe we can do some routing later on. func (rs *RelayService) openStreamToPeer(ctx context.Context, p peer.ID) (inet.Stream, error) { - return rs.host.NewStream(ctx, ID, p) + return rs.host.NewStream(ctx, p, ID) } func ReadHeader(r io.Reader) (src, dst peer.ID, err error) { diff --git a/p2p/protocol/relay/relay_test.go b/p2p/protocol/relay/relay_test.go index 858b525776e6c671e0725cf413e08d64edd647a8..e93491f793bf82ea94a9773bb02e54d676a98021 100644 --- a/p2p/protocol/relay/relay_test.go +++ b/p2p/protocol/relay/relay_test.go @@ -49,7 +49,7 @@ func TestRelaySimple(t *testing.T) { // ok, now we can try to relay n1--->n2--->n3. log.Debug("open relay stream") - s, err := n1.NewStream(ctx, relay.ID, n2p) + s, err := n1.NewStream(ctx, n2p, relay.ID) if err != nil { t.Fatal(err) } @@ -144,7 +144,7 @@ func TestRelayAcrossFour(t *testing.T) { // ok, now we can try to relay n1--->n2--->n3--->n4--->n5 log.Debug("open relay stream") - s, err := n1.NewStream(ctx, relay.ID, n2p) + s, err := n1.NewStream(ctx, n2p, relay.ID) if err != nil { t.Fatal(err) } @@ -244,7 +244,7 @@ func TestRelayStress(t *testing.T) { // ok, now we can try to relay n1--->n2--->n3. log.Debug("open relay stream") - s, err := n1.NewStream(ctx, relay.ID, n2p) + s, err := n1.NewStream(ctx, n2p, relay.ID) if err != nil { t.Fatal(err) } diff --git a/p2p/test/backpressure/backpressure_test.go b/p2p/test/backpressure/backpressure_test.go index 859692753c2696920d3fa3ed3e5fa05708d0fddc..aedc0e3b82928dfae9f33c667b82be68d023e1e8 100644 --- a/p2p/test/backpressure/backpressure_test.go +++ b/p2p/test/backpressure/backpressure_test.go @@ -83,7 +83,7 @@ a problem. }() for { - s, err = host.NewStream(context.Background(), protocol.TestingID, remote) + s, err = host.NewStream(context.Background(), remote, protocol.TestingID) if err != nil { return } @@ -285,7 +285,7 @@ func TestStBackpressureStreamWrite(t *testing.T) { } // open a stream, from 2->1, this is our reader - s, err := h2.NewStream(context.Background(), protocol.TestingID, h1.ID()) + s, err := h2.NewStream(context.Background(), h1.ID(), protocol.TestingID) if err != nil { t.Fatal(err) } diff --git a/p2p/test/reconnects/reconnect_test.go b/p2p/test/reconnects/reconnect_test.go index 49125a39740cf5d16669ed4f0470f662fbd6d243..aea77cd861b099e3f99f8d101b01f757dff35cce 100644 --- a/p2p/test/reconnects/reconnect_test.go +++ b/p2p/test/reconnects/reconnect_test.go @@ -177,7 +177,7 @@ func SubtestConnSendDisc(t *testing.T, hosts []host.Host) { for i := 0; i < numStreams; i++ { h1 := hosts[i%len(hosts)] h2 := hosts[(i+1)%len(hosts)] - s, err := h1.NewStream(context.Background(), protocol.TestingID, h2.ID()) + s, err := h1.NewStream(context.Background(), h2.ID(), protocol.TestingID) if err != nil { t.Error(err) }