Commit a8e25bf2 authored by Steven Allen's avatar Steven Allen
Browse files

mocknet: refuse to connect to self

The swarm does this as well and most of our services will fail if we don't have
this.
parent da772d14
...@@ -44,14 +44,9 @@ func FullMeshConnected(ctx context.Context, n int) (Mocknet, error) { ...@@ -44,14 +44,9 @@ func FullMeshConnected(ctx context.Context, n int) (Mocknet, error) {
return nil, err return nil, err
} }
nets := m.Nets() err = m.ConnectAllButSelf()
for _, n1 := range nets { if err != nil {
for _, n2 := range nets { return nil, err
if _, err := m.ConnectNets(n1, n2); err != nil {
return nil, err
}
}
} }
return m, nil return m, nil
} }
...@@ -118,6 +118,10 @@ func (pn *peernet) DialPeer(ctx context.Context, p peer.ID) (inet.Conn, error) { ...@@ -118,6 +118,10 @@ func (pn *peernet) DialPeer(ctx context.Context, p peer.ID) (inet.Conn, error) {
} }
func (pn *peernet) connect(p peer.ID) (*conn, error) { func (pn *peernet) connect(p peer.ID) (*conn, error) {
if p == pn.peer {
return nil, fmt.Errorf("attempted to dial self %s", p)
}
// first, check if we already have live connections // first, check if we already have live connections
pn.RLock() pn.RLock()
cs, found := pn.connsByPeer[p] cs, found := pn.connsByPeer[p]
......
...@@ -225,14 +225,14 @@ func TestNetworkSetup(t *testing.T) { ...@@ -225,14 +225,14 @@ func TestNetworkSetup(t *testing.T) {
t.Error("should not be able to connect") t.Error("should not be able to connect")
} }
// connect p1->p1 (should work) // connect p1->p1 (should fail)
if _, err := n1.DialPeer(ctx, p1); err != nil { if _, err := n1.DialPeer(ctx, p1); err == nil {
t.Error("p1 should be able to dial self.", err) t.Error("p1 shouldn't be able to dial self")
} }
// and a stream too // and a stream too
if _, err := n1.NewStream(ctx, p1); err != nil { if _, err := n1.NewStream(ctx, p1); err == nil {
t.Error(err) t.Error("p1 shouldn't be able to dial self")
} }
// connect p1->p2 // connect p1->p2
...@@ -383,8 +383,11 @@ func TestStreamsStress(t *testing.T) { ...@@ -383,8 +383,11 @@ func TestStreamsStress(t *testing.T) {
wg.Add(1) wg.Add(1)
go func(i int) { go func(i int) {
defer wg.Done() defer wg.Done()
from := rand.Intn(len(hosts)) var from, to int
to := rand.Intn(len(hosts)) for from == to {
from = rand.Intn(len(hosts))
to = rand.Intn(len(hosts))
}
s, err := hosts[from].NewStream(ctx, hosts[to].ID(), protocol.TestingID) s, err := hosts[from].NewStream(ctx, hosts[to].ID(), protocol.TestingID)
if err != nil { if err != nil {
log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to]) log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment