Commit f74add8a authored by Jeromy's avatar Jeromy Committed by Jeromy Johnson
Browse files

swarm: make stream.Protocol() return type protocol.ID

parent 36c66c0e
...@@ -116,10 +116,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) { ...@@ -116,10 +116,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) {
} }
return return
} }
s.SetProtocol(protocol.ID(protoID))
logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc) logStream := mstream.WrapStream(s, h.bwc)
s.SetProtocol(protoID)
go handle(protoID, logStream) go handle(protoID, logStream)
} }
...@@ -155,7 +155,7 @@ func (h *BasicHost) IDService() *identify.IDService { ...@@ -155,7 +155,7 @@ func (h *BasicHost) IDService() *identify.IDService {
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) { func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) {
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error { h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream) is := rwc.(inet.Stream)
is.SetProtocol(p) is.SetProtocol(protocol.ID(p))
handler(is) handler(is)
return nil return nil
}) })
...@@ -166,7 +166,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler ...@@ -166,7 +166,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) { func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error { h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream) is := rwc.(inet.Stream)
is.SetProtocol(p) is.SetProtocol(protocol.ID(p))
handler(is) handler(is)
return nil return nil
}) })
...@@ -187,27 +187,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I ...@@ -187,27 +187,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return h.newStream(ctx, p, pref) return h.newStream(ctx, p, pref)
} }
var lastErr error var protoStrs []string
for _, pid := range pids { for _, pid := range pids {
s, err := h.newStream(ctx, p, pid) protoStrs = append(protoStrs, string(pid))
if err != nil { }
lastErr = err
log.Infof("NewStream to %s for %s failed: %s", p, pid, err)
continue
}
_, err = s.Read(nil) s, err := h.Network().NewStream(ctx, p)
if err != nil { if err != nil {
lastErr = err return nil, err
log.Infof("NewStream to %s for %s failed (on read): %s", p, pid, err) }
continue
}
h.setPreferredProtocol(p, pid) selected, err := msmux.SelectOneOf(protoStrs, s)
return s, nil if err != nil {
s.Close()
return nil, err
} }
selpid := protocol.ID(selected)
s.SetProtocol(selpid)
h.setPreferredProtocol(p, selpid)
return nil, lastErr return mstream.WrapStream(s, h.bwc), nil
} }
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID { func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID {
...@@ -257,9 +256,9 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) ( ...@@ -257,9 +256,9 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (
return nil, err return nil, err
} }
s.SetProtocol(string(pid)) s.SetProtocol(pid)
logStream := mstream.WrapStream(s, pid, h.bwc) logStream := mstream.WrapStream(s, h.bwc)
lzcon := msmux.NewMSSelect(logStream, string(pid)) lzcon := msmux.NewMSSelect(logStream, string(pid))
return &streamWrapper{ return &streamWrapper{
......
...@@ -76,7 +76,7 @@ func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) { ...@@ -76,7 +76,7 @@ func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) {
return h1, h2 return h1, h2
} }
func assertWait(t *testing.T, c chan string, exp string) { func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) {
select { select {
case proto := <-c: case proto := <-c:
if proto != exp { if proto != exp {
...@@ -99,7 +99,7 @@ func TestHostProtoPreference(t *testing.T) { ...@@ -99,7 +99,7 @@ func TestHostProtoPreference(t *testing.T) {
protoNew := protocol.ID("/testing/1.1.0") protoNew := protocol.ID("/testing/1.1.0")
protoMinor := protocol.ID("/testing/1.2.0") protoMinor := protocol.ID("/testing/1.2.0")
connectedOn := make(chan string, 16) connectedOn := make(chan protocol.ID, 16)
handler := func(s inet.Stream) { handler := func(s inet.Stream) {
connectedOn <- s.Protocol() connectedOn <- s.Protocol()
...@@ -113,10 +113,10 @@ func TestHostProtoPreference(t *testing.T) { ...@@ -113,10 +113,10 @@ func TestHostProtoPreference(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
assertWait(t, connectedOn, string(protoOld)) assertWait(t, connectedOn, protoOld)
s.Close() s.Close()
mfunc, err := host.MultistreamSemverMatcher(string(protoMinor)) mfunc, err := host.MultistreamSemverMatcher(protoMinor)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -135,7 +135,7 @@ func TestHostProtoPreference(t *testing.T) { ...@@ -135,7 +135,7 @@ func TestHostProtoPreference(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
assertWait(t, connectedOn, string(protoOld)) assertWait(t, connectedOn, protoOld)
s2.Close() s2.Close()
...@@ -144,12 +144,7 @@ func TestHostProtoPreference(t *testing.T) { ...@@ -144,12 +144,7 @@ func TestHostProtoPreference(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = s3.Read(nil) assertWait(t, connectedOn, protoMinor)
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, string(protoMinor))
s3.Close() s3.Close()
} }
...@@ -179,7 +174,7 @@ func TestHostProtoPreknowledge(t *testing.T) { ...@@ -179,7 +174,7 @@ func TestHostProtoPreknowledge(t *testing.T) {
h1 := testutil.GenHostSwarm(t, ctx) h1 := testutil.GenHostSwarm(t, ctx)
h2 := testutil.GenHostSwarm(t, ctx) h2 := testutil.GenHostSwarm(t, ctx)
conn := make(chan string, 16) conn := make(chan protocol.ID, 16)
handler := func(s inet.Stream) { handler := func(s inet.Stream) {
conn <- s.Protocol() conn <- s.Protocol()
s.Close() s.Close()
......
package host package host
import ( import (
"github.com/libp2p/go-libp2p/p2p/protocol"
"strings" "strings"
semver "github.com/coreos/go-semver/semver" semver "github.com/coreos/go-semver/semver"
) )
func MultistreamSemverMatcher(base string) (func(string) bool, error) { func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) {
parts := strings.Split(base, "/") parts := strings.Split(string(base), "/")
vers, err := semver.NewVersion(parts[len(parts)-1]) vers, err := semver.NewVersion(parts[len(parts)-1])
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -19,18 +19,18 @@ type meteredStream struct { ...@@ -19,18 +19,18 @@ type meteredStream struct {
mesRecv metrics.StreamMeterCallback mesRecv metrics.StreamMeterCallback
} }
func newMeteredStream(base inet.Stream, pid protocol.ID, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream { func newMeteredStream(base inet.Stream, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
return &meteredStream{ return &meteredStream{
Stream: base, Stream: base,
mesSent: sentCB, mesSent: sentCB,
mesRecv: recvCB, mesRecv: recvCB,
protoKey: pid, protoKey: base.Protocol(),
peerKey: p, peerKey: p,
} }
} }
func WrapStream(base inet.Stream, pid protocol.ID, bwc metrics.Reporter) inet.Stream { func WrapStream(base inet.Stream, bwc metrics.Reporter) inet.Stream {
return newMeteredStream(base, pid, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream) return newMeteredStream(base, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
} }
func (s *meteredStream) Read(b []byte) (int, error) { func (s *meteredStream) Read(b []byte) (int, error) {
......
...@@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) { ...@@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (fs *FakeStream) Protocol() protocol.ID {
return "TEST"
}
func TestCallbacksWork(t *testing.T) { func TestCallbacksWork(t *testing.T) {
fake := new(FakeStream) fake := new(FakeStream)
...@@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) { ...@@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) {
recv += n recv += n
} }
ms := newMeteredStream(fake, protocol.ID("TEST"), peer.ID("PEER"), recvCB, sentCB) ms := newMeteredStream(fake, peer.ID("PEER"), recvCB, sentCB)
toWrite := int64(100000) toWrite := int64(100000)
toRead := int64(100000) toRead := int64(100000)
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
ma "github.com/jbenet/go-multiaddr" ma "github.com/jbenet/go-multiaddr"
"github.com/jbenet/goprocess" "github.com/jbenet/goprocess"
conn "github.com/libp2p/go-libp2p/p2p/net/conn" conn "github.com/libp2p/go-libp2p/p2p/net/conn"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
context "golang.org/x/net/context" context "golang.org/x/net/context"
) )
...@@ -26,8 +27,8 @@ type Stream interface { ...@@ -26,8 +27,8 @@ type Stream interface {
io.Writer io.Writer
io.Closer io.Closer
Protocol() string Protocol() protocol.ID
SetProtocol(string) SetProtocol(protocol.ID)
// Conn returns the connection this stream is part of. // Conn returns the connection this stream is part of.
Conn() Conn Conn() Conn
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
process "github.com/jbenet/goprocess" process "github.com/jbenet/goprocess"
inet "github.com/libp2p/go-libp2p/p2p/net" inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
) )
// stream implements inet.Stream // stream implements inet.Stream
...@@ -17,7 +18,7 @@ type stream struct { ...@@ -17,7 +18,7 @@ type stream struct {
toDeliver chan *transportObject toDeliver chan *transportObject
proc process.Process proc process.Process
protocol string protocol protocol.ID
} }
type transportObject struct { type transportObject struct {
...@@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) { ...@@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
func (s *stream) Protocol() string { func (s *stream) Protocol() protocol.ID {
return s.protocol return s.protocol
} }
func (s *stream) SetProtocol(proto string) { func (s *stream) SetProtocol(proto protocol.ID) {
s.protocol = proto s.protocol = proto
} }
......
...@@ -2,6 +2,7 @@ package swarm ...@@ -2,6 +2,7 @@ package swarm
import ( import (
inet "github.com/libp2p/go-libp2p/p2p/net" inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
ps "github.com/jbenet/go-peerstream" ps "github.com/jbenet/go-peerstream"
) )
...@@ -10,7 +11,7 @@ import ( ...@@ -10,7 +11,7 @@ import (
// our Conn and Swarm (instead of just the ps.Conn and ps.Swarm) // our Conn and Swarm (instead of just the ps.Conn and ps.Swarm)
type Stream struct { type Stream struct {
stream *ps.Stream stream *ps.Stream
protocol string protocol protocol.ID
} }
// Stream returns the underlying peerstream.Stream // Stream returns the underlying peerstream.Stream
...@@ -44,11 +45,11 @@ func (s *Stream) Close() error { ...@@ -44,11 +45,11 @@ func (s *Stream) Close() error {
return s.stream.Close() return s.stream.Close()
} }
func (s *Stream) Protocol() string { func (s *Stream) Protocol() protocol.ID {
return s.protocol return s.protocol
} }
func (s *Stream) SetProtocol(p string) { func (s *Stream) SetProtocol(p protocol.ID) {
s.protocol = p s.protocol = p
} }
......
...@@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) { ...@@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) {
return return
} }
s.SetProtocol(ID)
bwc := ids.Host.GetBandwidthReporter() bwc := ids.Host.GetBandwidthReporter()
s = mstream.WrapStream(s, ID, bwc) s = mstream.WrapStream(s, bwc)
// ok give the response to our handler. // ok give the response to our handler.
if err := msmux.SelectProtoOrFail(ID, s); err != nil { if err := msmux.SelectProtoOrFail(ID, s); err != nil {
...@@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) { ...@@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) {
c := s.Conn() c := s.Conn()
bwc := ids.Host.GetBandwidthReporter() bwc := ids.Host.GetBandwidthReporter()
s = mstream.WrapStream(s, ID, bwc) s = mstream.WrapStream(s, bwc)
w := ggio.NewDelimitedWriter(s) w := ggio.NewDelimitedWriter(s)
mes := pb.Identify{} mes := pb.Identify{}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment