Commit f74add8a authored by Jeromy's avatar Jeromy Committed by Jeromy Johnson
Browse files

swarm: make stream.Protocol() return type protocol.ID

parent 36c66c0e
......@@ -116,10 +116,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) {
}
return
}
s.SetProtocol(protocol.ID(protoID))
logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc)
logStream := mstream.WrapStream(s, h.bwc)
s.SetProtocol(protoID)
go handle(protoID, logStream)
}
......@@ -155,7 +155,7 @@ func (h *BasicHost) IDService() *identify.IDService {
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) {
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream)
is.SetProtocol(p)
is.SetProtocol(protocol.ID(p))
handler(is)
return nil
})
......@@ -166,7 +166,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream)
is.SetProtocol(p)
is.SetProtocol(protocol.ID(p))
handler(is)
return nil
})
......@@ -187,27 +187,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return h.newStream(ctx, p, pref)
}
var lastErr error
var protoStrs []string
for _, pid := range pids {
s, err := h.newStream(ctx, p, pid)
if err != nil {
lastErr = err
log.Infof("NewStream to %s for %s failed: %s", p, pid, err)
continue
protoStrs = append(protoStrs, string(pid))
}
_, err = s.Read(nil)
s, err := h.Network().NewStream(ctx, p)
if err != nil {
lastErr = err
log.Infof("NewStream to %s for %s failed (on read): %s", p, pid, err)
continue
return nil, err
}
h.setPreferredProtocol(p, pid)
return s, nil
selected, err := msmux.SelectOneOf(protoStrs, s)
if err != nil {
s.Close()
return nil, err
}
selpid := protocol.ID(selected)
s.SetProtocol(selpid)
h.setPreferredProtocol(p, selpid)
return nil, lastErr
return mstream.WrapStream(s, h.bwc), nil
}
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID {
......@@ -257,9 +256,9 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (
return nil, err
}
s.SetProtocol(string(pid))
s.SetProtocol(pid)
logStream := mstream.WrapStream(s, pid, h.bwc)
logStream := mstream.WrapStream(s, h.bwc)
lzcon := msmux.NewMSSelect(logStream, string(pid))
return &streamWrapper{
......
......@@ -76,7 +76,7 @@ func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) {
return h1, h2
}
func assertWait(t *testing.T, c chan string, exp string) {
func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) {
select {
case proto := <-c:
if proto != exp {
......@@ -99,7 +99,7 @@ func TestHostProtoPreference(t *testing.T) {
protoNew := protocol.ID("/testing/1.1.0")
protoMinor := protocol.ID("/testing/1.2.0")
connectedOn := make(chan string, 16)
connectedOn := make(chan protocol.ID, 16)
handler := func(s inet.Stream) {
connectedOn <- s.Protocol()
......@@ -113,10 +113,10 @@ func TestHostProtoPreference(t *testing.T) {
t.Fatal(err)
}
assertWait(t, connectedOn, string(protoOld))
assertWait(t, connectedOn, protoOld)
s.Close()
mfunc, err := host.MultistreamSemverMatcher(string(protoMinor))
mfunc, err := host.MultistreamSemverMatcher(protoMinor)
if err != nil {
t.Fatal(err)
}
......@@ -135,7 +135,7 @@ func TestHostProtoPreference(t *testing.T) {
t.Fatal(err)
}
assertWait(t, connectedOn, string(protoOld))
assertWait(t, connectedOn, protoOld)
s2.Close()
......@@ -144,12 +144,7 @@ func TestHostProtoPreference(t *testing.T) {
t.Fatal(err)
}
_, err = s3.Read(nil)
if err != nil {
t.Fatal(err)
}
assertWait(t, connectedOn, string(protoMinor))
assertWait(t, connectedOn, protoMinor)
s3.Close()
}
......@@ -179,7 +174,7 @@ func TestHostProtoPreknowledge(t *testing.T) {
h1 := testutil.GenHostSwarm(t, ctx)
h2 := testutil.GenHostSwarm(t, ctx)
conn := make(chan string, 16)
conn := make(chan protocol.ID, 16)
handler := func(s inet.Stream) {
conn <- s.Protocol()
s.Close()
......
package host
import (
"github.com/libp2p/go-libp2p/p2p/protocol"
"strings"
semver "github.com/coreos/go-semver/semver"
)
func MultistreamSemverMatcher(base string) (func(string) bool, error) {
parts := strings.Split(base, "/")
func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) {
parts := strings.Split(string(base), "/")
vers, err := semver.NewVersion(parts[len(parts)-1])
if err != nil {
return nil, err
......
......@@ -19,18 +19,18 @@ type meteredStream struct {
mesRecv metrics.StreamMeterCallback
}
func newMeteredStream(base inet.Stream, pid protocol.ID, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
func newMeteredStream(base inet.Stream, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream {
return &meteredStream{
Stream: base,
mesSent: sentCB,
mesRecv: recvCB,
protoKey: pid,
protoKey: base.Protocol(),
peerKey: p,
}
}
func WrapStream(base inet.Stream, pid protocol.ID, bwc metrics.Reporter) inet.Stream {
return newMeteredStream(base, pid, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
func WrapStream(base inet.Stream, bwc metrics.Reporter) inet.Stream {
return newMeteredStream(base, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream)
}
func (s *meteredStream) Read(b []byte) (int, error) {
......
......@@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) {
return len(b), nil
}
func (fs *FakeStream) Protocol() protocol.ID {
return "TEST"
}
func TestCallbacksWork(t *testing.T) {
fake := new(FakeStream)
......@@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) {
recv += n
}
ms := newMeteredStream(fake, protocol.ID("TEST"), peer.ID("PEER"), recvCB, sentCB)
ms := newMeteredStream(fake, peer.ID("PEER"), recvCB, sentCB)
toWrite := int64(100000)
toRead := int64(100000)
......
......@@ -8,6 +8,7 @@ import (
ma "github.com/jbenet/go-multiaddr"
"github.com/jbenet/goprocess"
conn "github.com/libp2p/go-libp2p/p2p/net/conn"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
context "golang.org/x/net/context"
)
......@@ -26,8 +27,8 @@ type Stream interface {
io.Writer
io.Closer
Protocol() string
SetProtocol(string)
Protocol() protocol.ID
SetProtocol(protocol.ID)
// Conn returns the connection this stream is part of.
Conn() Conn
......
......@@ -7,6 +7,7 @@ import (
process "github.com/jbenet/goprocess"
inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
)
// stream implements inet.Stream
......@@ -17,7 +18,7 @@ type stream struct {
toDeliver chan *transportObject
proc process.Process
protocol string
protocol protocol.ID
}
type transportObject struct {
......@@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (s *stream) Protocol() string {
func (s *stream) Protocol() protocol.ID {
return s.protocol
}
func (s *stream) SetProtocol(proto string) {
func (s *stream) SetProtocol(proto protocol.ID) {
s.protocol = proto
}
......
......@@ -2,6 +2,7 @@ package swarm
import (
inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
ps "github.com/jbenet/go-peerstream"
)
......@@ -10,7 +11,7 @@ import (
// our Conn and Swarm (instead of just the ps.Conn and ps.Swarm)
type Stream struct {
stream *ps.Stream
protocol string
protocol protocol.ID
}
// Stream returns the underlying peerstream.Stream
......@@ -44,11 +45,11 @@ func (s *Stream) Close() error {
return s.stream.Close()
}
func (s *Stream) Protocol() string {
func (s *Stream) Protocol() protocol.ID {
return s.protocol
}
func (s *Stream) SetProtocol(p string) {
func (s *Stream) SetProtocol(p protocol.ID) {
s.protocol = p
}
......
......@@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) {
return
}
s.SetProtocol(ID)
bwc := ids.Host.GetBandwidthReporter()
s = mstream.WrapStream(s, ID, bwc)
s = mstream.WrapStream(s, bwc)
// ok give the response to our handler.
if err := msmux.SelectProtoOrFail(ID, s); err != nil {
......@@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) {
c := s.Conn()
bwc := ids.Host.GetBandwidthReporter()
s = mstream.WrapStream(s, ID, bwc)
s = mstream.WrapStream(s, bwc)
w := ggio.NewDelimitedWriter(s)
mes := pb.Identify{}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment