diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 0c7dea29fefd06ebd022420c983266cb143c2bdc..caf2d8f7bcd0b12442fc70e2ed34ac1174d30afa 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -241,3 +241,62 @@ func TestNewDialOld(t *testing.T) { s.Close() } + +func TestProtoDowngrade(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1, h2 := getHostPair(ctx, t) + defer h1.Close() + defer h2.Close() + + connectedOn := make(chan protocol.ID, 16) + h1.SetStreamHandler("/testing/1.0.0", func(s inet.Stream) { + connectedOn <- s.Protocol() + s.Close() + }) + + s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") + if err != nil { + t.Fatal(err) + } + + assertWait(t, connectedOn, "/testing/1.0.0") + + if s.Protocol() != "/testing/1.0.0" { + t.Fatal("shoould have gotten /testing") + } + s.Close() + + h1.Network().ConnsToPeer(h2.ID())[0].Close() + + time.Sleep(time.Millisecond * 50) // allow notifications to propogate + h1.RemoveStreamHandler("/testing/1.0.0") + h1.SetStreamHandler("/testing", func(s inet.Stream) { + connectedOn <- s.Protocol() + s.Close() + }) + + h2pi := h2.Peerstore().PeerInfo(h2.ID()) + if err := h1.Connect(ctx, h2pi); err != nil { + t.Fatal(err) + } + + s2, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") + if err != nil { + t.Fatal(err) + } + + _, err = s2.Write(nil) + if err != nil { + t.Fatal(err) + } + + assertWait(t, connectedOn, "/testing") + + if s2.Protocol() != "/testing" { + t.Fatal("shoould have gotten /testing") + } + s2.Close() + +} diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 67bacfd2aa56e1a3ea3f2eeebbaa56d60f6e61b8..ebd00248690eae2e5d373c95674cdb860ba8ba07 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -193,7 +193,7 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c inet.Conn) { p := c.RemotePeer() // mes.Protocols - ids.Host.Peerstore().AddProtocols(p, mes.Protocols...) + ids.Host.Peerstore().SetProtocols(p, mes.Protocols...) // mes.ObservedAddr ids.consumeObservedAddress(mes.GetObservedAddr(), c)