diff --git a/p2p/discovery/mdns.go b/p2p/discovery/mdns.go index bc9c1f66ce46d23ad4d03f0944ea552c755ebbaf..85305c477db8e18d496f10cf57f7efe360693ebd 100644 --- a/p2p/discovery/mdns.go +++ b/p2p/discovery/mdns.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "errors" "io" "io/ioutil" @@ -60,7 +61,7 @@ func getDialableListenAddrs(ph host.Host) ([]*net.TCPAddr, error) { return out, nil } -func NewMdnsService(peerhost host.Host, interval time.Duration) (Service, error) { +func NewMdnsService(ctx context.Context, peerhost host.Host, interval time.Duration) (Service, error) { // TODO: dont let mdns use logging... golog.SetOutput(ioutil.Discard) @@ -99,7 +100,7 @@ func NewMdnsService(peerhost host.Host, interval time.Duration) (Service, error) interval: interval, } - go s.pollForEntries() + go s.pollForEntries(ctx) return s, nil } @@ -108,31 +109,42 @@ func (m *mdnsService) Close() error { return m.server.Shutdown() } -func (m *mdnsService) pollForEntries() { +func (m *mdnsService) pollForEntries(ctx context.Context) { + ticker := time.NewTicker(m.interval) - for range ticker.C { - entriesCh := make(chan *mdns.ServiceEntry, 16) - go func() { - for entry := range entriesCh { - m.handleEntry(entry) + for { + select { + case <-ticker.C: + entriesCh := make(chan *mdns.ServiceEntry, 16) + go func() { + for entry := range entriesCh { + m.handleEntry(entry) + } + }() + + log.Debug("starting mdns query") + qp := &mdns.QueryParam{ + Domain: "local", + Entries: entriesCh, + Service: ServiceTag, + Timeout: time.Second * 5, } - }() - - qp := mdns.QueryParam{} - qp.Domain = "local" - qp.Entries = entriesCh - qp.Service = ServiceTag - qp.Timeout = time.Second * 5 - err := mdns.Query(&qp) - if err != nil { - log.Error("mdns lookup error: ", err) + err := mdns.Query(qp) + if err != nil { + log.Error("mdns lookup error: ", err) + } + close(entriesCh) + log.Debug("mdns query complete") + case <-ctx.Done(): + log.Debug("mdns service halting") + return } - close(entriesCh) } } func (m *mdnsService) handleEntry(e *mdns.ServiceEntry) { + log.Debugf("Handling MDNS entry: %s:%d %s", e.AddrV4, e.Port, e.Info) mpeer, err := peer.IDB58Decode(e.Info) if err != nil { log.Warning("Error parsing peer ID from mdns entry: ", err) @@ -140,6 +152,7 @@ func (m *mdnsService) handleEntry(e *mdns.ServiceEntry) { } if mpeer == m.host.ID() { + log.Debug("got our own mdns entry, skipping") return } @@ -159,7 +172,7 @@ func (m *mdnsService) handleEntry(e *mdns.ServiceEntry) { m.lk.Lock() for _, n := range m.notifees { - n.HandlePeerFound(pi) + go n.HandlePeerFound(pi) } m.lk.Unlock() } diff --git a/p2p/discovery/mdns_test.go b/p2p/discovery/mdns_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f0282facffa263d8be1c0922fc9d5621e28e8aea --- /dev/null +++ b/p2p/discovery/mdns_test.go @@ -0,0 +1,51 @@ +package discovery + +import ( + "context" + "testing" + "time" + + host "github.com/libp2p/go-libp2p/p2p/host" + netutil "github.com/libp2p/go-libp2p/p2p/test/util" + + pstore "github.com/ipfs/go-libp2p-peerstore" +) + +type DiscoveryNotifee struct { + h host.Host +} + +func (n *DiscoveryNotifee) HandlePeerFound(pi pstore.PeerInfo) { + n.h.Connect(context.Background(), pi) +} + +func TestMdnsDiscovery(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a := netutil.GenHostSwarm(t, ctx) + b := netutil.GenHostSwarm(t, ctx) + + sa, err := NewMdnsService(ctx, a, time.Second) + if err != nil { + t.Fatal(err) + } + + sb, err := NewMdnsService(ctx, b, time.Second) + if err != nil { + t.Fatal(err) + } + + _ = sb + + n := &DiscoveryNotifee{a} + + sa.RegisterNotifee(n) + + time.Sleep(time.Second * 2) + + err = a.Connect(ctx, pstore.PeerInfo{ID: b.ID()}) + if err != nil { + t.Fatal(err) + } +}