diff --git a/p2p/protocol/ping/ping.go b/p2p/protocol/ping/ping.go index d6e571e0f31f9e28382df0b5116b8e5ad1bc3fa8..54464f3080b125da15cdce82e0682aa203b8dab3 100644 --- a/p2p/protocol/ping/ping.go +++ b/p2p/protocol/ping/ping.go @@ -33,33 +33,42 @@ func NewPingService(h host.Host) *PingService { } func (p *PingService) PingHandler(s inet.Stream) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - buf := make([]byte, PingSize) + errCh := make(chan error, 1) + defer close(errCh) timer := time.NewTimer(pingTimeout) defer timer.Stop() go func() { select { case <-timer.C: - case <-ctx.Done(): + log.Debug("ping timeout") + s.Reset() + case err, ok := <-errCh: + if ok { + log.Debug(err) + if err == io.EOF { + s.Close() + } else { + s.Reset() + } + } else { + log.Error("ping loop failed without error") + } } - - s.Close() }() for { _, err := io.ReadFull(s, buf) if err != nil { - log.Debug(err) + errCh <- err return } _, err = s.Write(buf) if err != nil { - log.Debug(err) + errCh <- err return } @@ -84,6 +93,7 @@ func (ps *PingService) Ping(ctx context.Context, p peer.ID) (<-chan time.Duratio default: t, err := ping(s) if err != nil { + s.Reset() log.Debugf("ping error: %s", err) return }