diff --git a/p2p/net/mock/mock.go b/p2p/net/mock/mock.go index 8760b4d48902b180a9d158b8cca42287c9d7ee83..55256f867fd78055e8351642e008c758daec4cdf 100644 --- a/p2p/net/mock/mock.go +++ b/p2p/net/mock/mock.go @@ -44,14 +44,9 @@ func FullMeshConnected(ctx context.Context, n int) (Mocknet, error) { return nil, err } - nets := m.Nets() - for _, n1 := range nets { - for _, n2 := range nets { - if _, err := m.ConnectNets(n1, n2); err != nil { - return nil, err - } - } + err = m.ConnectAllButSelf() + if err != nil { + return nil, err } - return m, nil } diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 3c5710bf915476b67dabc2c36ff253fbc35ad2ea..52331016ca7ce8887a9ba2da7ad40dcb7ca695bb 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -118,6 +118,10 @@ func (pn *peernet) DialPeer(ctx context.Context, p peer.ID) (inet.Conn, error) { } func (pn *peernet) connect(p peer.ID) (*conn, error) { + if p == pn.peer { + return nil, fmt.Errorf("attempted to dial self %s", p) + } + // first, check if we already have live connections pn.RLock() cs, found := pn.connsByPeer[p] diff --git a/p2p/net/mock/mock_test.go b/p2p/net/mock/mock_test.go index 0f888a7098c100744eb5e5bc3320126afd2d3289..d6d3c72680a68f4f29469657aca41c2af665e0a0 100644 --- a/p2p/net/mock/mock_test.go +++ b/p2p/net/mock/mock_test.go @@ -225,14 +225,14 @@ func TestNetworkSetup(t *testing.T) { t.Error("should not be able to connect") } - // connect p1->p1 (should work) - if _, err := n1.DialPeer(ctx, p1); err != nil { - t.Error("p1 should be able to dial self.", err) + // connect p1->p1 (should fail) + if _, err := n1.DialPeer(ctx, p1); err == nil { + t.Error("p1 shouldn't be able to dial self") } // and a stream too - if _, err := n1.NewStream(ctx, p1); err != nil { - t.Error(err) + if _, err := n1.NewStream(ctx, p1); err == nil { + t.Error("p1 shouldn't be able to dial self") } // connect p1->p2 @@ -383,8 +383,11 @@ func TestStreamsStress(t *testing.T) { wg.Add(1) go func(i int) { defer wg.Done() - from := rand.Intn(len(hosts)) - to := rand.Intn(len(hosts)) + var from, to int + for from == to { + from = rand.Intn(len(hosts)) + to = rand.Intn(len(hosts)) + } s, err := hosts[from].NewStream(ctx, hosts[to].ID(), protocol.TestingID) if err != nil { log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to])