diff --git a/examples/hosts/main.go b/examples/hosts/main.go index f92be28b3a446bc3a0b931f83bf26c2fb029aac0..2b51fb1cafbc7cbe30ccde9dfd6247d358df4040 100644 --- a/examples/hosts/main.go +++ b/examples/hosts/main.go @@ -5,7 +5,6 @@ import ( "fmt" "io/ioutil" "log" - "time" pstore "github.com/ipfs/go-libp2p-peerstore" host "github.com/ipfs/go-libp2p/p2p/host" diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index b7e94632e75e9e42472d3adcd74d20a904c85fae..6e227c48c256e5951e1269945d95337058feed23 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -114,7 +114,8 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) { logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc) - go handle(logStream) + s.SetProtocol(protoID) + go handle(protoID, logStream) } // ID returns the (local) peer.ID associated with this Host @@ -147,8 +148,10 @@ func (h *BasicHost) IDService() *identify.IDService { // host.Mux().SetHandler(proto, handler) // (Threadsafe) func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) { - h.Mux().AddHandler(string(pid), func(rwc io.ReadWriteCloser) error { - handler(rwc.(inet.Stream)) + h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error { + is := rwc.(inet.Stream) + is.SetProtocol(p) + handler(is) return nil }) } diff --git a/p2p/net/conn/dial_test.go b/p2p/net/conn/dial_test.go index f26f76f3bb3b94452f8ac0f1a23e1a39a52b2ce9..84ddbee778d2c62985952d2f9c363d47f8e9fdda 100644 --- a/p2p/net/conn/dial_test.go +++ b/p2p/net/conn/dial_test.go @@ -23,7 +23,7 @@ import ( ) func goroFilter(r *grc.Goroutine) bool { - return strings.Contains(r.Function, "go-log.") + return strings.Contains(r.Function, "go-log.") || strings.Contains(r.Stack[0], "testing.(*T).Run") } func echoListen(ctx context.Context, listener Listener) { diff --git a/p2p/net/interface.go b/p2p/net/interface.go index 81d292eab3781d9f6fac4ebf0914f8e7e0b1d342..474d879211c627b22270e6d031a5d7db00eef6ce 100644 --- a/p2p/net/interface.go +++ b/p2p/net/interface.go @@ -26,6 +26,9 @@ type Stream interface { io.Writer io.Closer + Protocol() string + SetProtocol(string) + // Conn returns the connection this stream is part of. Conn() Conn } diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 78b4d750ec3bda3d0f3f4df3a2c6bb35b75268a8..62923813c4dc97d29a960fc052c035bc6ebd10ba 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -16,6 +16,8 @@ type stream struct { conn *conn toDeliver chan *transportObject proc process.Process + + protocol string } type transportObject struct { @@ -48,6 +50,14 @@ func (s *stream) Write(p []byte) (n int, err error) { return len(p), nil } +func (s *stream) Protocol() string { + return s.protocol +} + +func (s *stream) SetProtocol(proto string) { + s.protocol = proto +} + func (s *stream) Close() error { return s.proc.Close() } diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 15e1ec5e1b868a1a84df00a5c4d0231566273f59..1fcf0f1c14acc09a4de5f502de677149a15507c8 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -340,9 +340,9 @@ func (n *ps2netNotifee) Disconnected(c *ps.Conn) { } func (n *ps2netNotifee) OpenedStream(s *ps.Stream) { - n.not.OpenedStream(n.net, inet.Stream((*Stream)(s))) + n.not.OpenedStream(n.net, &Stream{stream: s}) } func (n *ps2netNotifee) ClosedStream(s *ps.Stream) { - n.not.ClosedStream(n.net, inet.Stream((*Stream)(s))) + n.not.ClosedStream(n.net, &Stream{stream: s}) } diff --git a/p2p/net/swarm/swarm_notif_test.go b/p2p/net/swarm/swarm_notif_test.go index aa4fca2448885783fa9cd01b565983facd33e37d..baced67e61ee17e9f5445503d4d14d2634755c4e 100644 --- a/p2p/net/swarm/swarm_notif_test.go +++ b/p2p/net/swarm/swarm_notif_test.go @@ -10,6 +10,12 @@ import ( context "golang.org/x/net/context" ) +func streamsSame(a, b inet.Stream) bool { + sa := a.(*Stream) + sb := b.(*Stream) + return sa.Stream() == sb.Stream() +} + func TestNotifications(t *testing.T) { ctx := context.Background() swarms := makeSwarms(ctx, t, 5) @@ -98,7 +104,7 @@ func TestNotifications(t *testing.T) { case <-time.After(timeout): t.Fatal("timeout") } - if s != s2 { + if !streamsSame(s, s2) { t.Fatal("got incorrect stream", s.Conn(), s2.Conn()) } @@ -108,7 +114,7 @@ func TestNotifications(t *testing.T) { case <-time.After(timeout): t.Fatal("timeout") } - if s != s2 { + if !streamsSame(s, s2) { t.Fatal("got incorrect stream", s.Conn(), s2.Conn()) } } diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index 7965d2743984d85d34e2e8c3ae61dcf897529d41..16aa6e07c9ec8864db6ab637df8658a3b7949d26 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -8,11 +8,14 @@ import ( // a Stream is a wrapper around a ps.Stream that exposes a way to get // our Conn and Swarm (instead of just the ps.Conn and ps.Swarm) -type Stream ps.Stream +type Stream struct { + stream *ps.Stream + protocol string +} // Stream returns the underlying peerstream.Stream func (s *Stream) Stream() *ps.Stream { - return (*ps.Stream)(s) + return s.stream } // Conn returns the Conn associated with this Stream, as an inet.Conn @@ -22,27 +25,37 @@ func (s *Stream) Conn() inet.Conn { // SwarmConn returns the Conn associated with this Stream, as a *Conn func (s *Stream) SwarmConn() *Conn { - return (*Conn)(s.Stream().Conn()) + return (*Conn)(s.stream.Conn()) } // Read reads bytes from a stream. func (s *Stream) Read(p []byte) (n int, err error) { - return s.Stream().Read(p) + return s.stream.Read(p) } // Write writes bytes to a stream, flushing for each call. func (s *Stream) Write(p []byte) (n int, err error) { - return s.Stream().Write(p) + return s.stream.Write(p) } // Close closes the stream, indicating this side is finished // with the stream. func (s *Stream) Close() error { - return s.Stream().Close() + return s.stream.Close() +} + +func (s *Stream) Protocol() string { + return s.protocol +} + +func (s *Stream) SetProtocol(p string) { + s.protocol = p } func wrapStream(pss *ps.Stream) *Stream { - return (*Stream)(pss) + return &Stream{ + stream: pss, + } } func wrapStreams(st []*ps.Stream) []*Stream { diff --git a/package.json b/package.json index e2ce569b5a76d1c6286e5017ccd1d1d86ea1b20b..b28b1a962339f966c652071da614f66bc715fa3c 100644 --- a/package.json +++ b/package.json @@ -34,9 +34,9 @@ "version": "1.0.0" }, { - "hash": "Qmf91yhgRLo2dhhbc5zZ7TxjMaR1oxaWaoc9zRZdi1kU4a", + "hash": "Qmc8WfU6Ci9e1qvTNYE3EUwrHEXfpxY7dNrWtVtjpYcp2P", "name": "go-multistream", - "version": "0.0.0" + "version": "0.1.0" }, { "hash": "QmNLvkCDV6ZjUJsEwGNporYBuZdhWT6q7TBVYQwwRv12HT", @@ -171,9 +171,9 @@ }, { "author": "whyrusleeping", - "hash": "QmVcmcQE9eX4HQ8QwhVXpoHt3ennG7d299NDYFq9D1Uqa1", + "hash": "QmRbnoT3xJXpi37Vc11e6VYV4RXKWUMsZtfbBQkR43377P", "name": "go-smux-multistream", - "version": "1.0.0" + "version": "1.1.0" }, { "author": "whyrusleeping",