Commit b077d394 authored by Jeromy Johnson's avatar Jeromy Johnson Committed by GitHub
Browse files

Merge pull request #86 from libp2p/feat/semver-mstream-matching

p2p/host: expose multistream function matching and add semver func
parents 1868c67c de84a414
...@@ -156,6 +156,17 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler ...@@ -156,6 +156,17 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler
}) })
} }
// SetStreamHandlerMatch sets the protocol handler on the Host's Mux
// using a matching function to do protocol comparisons
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
is := rwc.(inet.Stream)
is.SetProtocol(p)
handler(is)
return nil
})
}
// RemoveStreamHandler returns .. // RemoveStreamHandler returns ..
func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
h.Mux().RemoveHandler(string(pid)) h.Mux().RemoveHandler(string(pid))
......
...@@ -48,6 +48,10 @@ type Host interface { ...@@ -48,6 +48,10 @@ type Host interface {
// (Threadsafe) // (Threadsafe)
SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler)
// SetStreamHandlerMatch sets the protocol handler on the Host's Mux
// using a matching function for protocol selection.
SetStreamHandlerMatch(protocol.ID, func(string) bool, inet.StreamHandler)
// RemoveStreamHandler removes a handler on the mux that was set by // RemoveStreamHandler removes a handler on the mux that was set by
// SetStreamHandler // SetStreamHandler
RemoveStreamHandler(pid protocol.ID) RemoveStreamHandler(pid protocol.ID)
......
package host
import (
"strings"
semver "github.com/coreos/go-semver/semver"
)
func MultistreamSemverMatcher(base string) (func(string) bool, error) {
parts := strings.Split(base, "/")
vers, err := semver.NewVersion(parts[len(parts)-1])
if err != nil {
return nil, err
}
return func(check string) bool {
chparts := strings.Split(check, "/")
if len(chparts) != len(parts) {
return false
}
for i, v := range chparts[:len(chparts)-1] {
if parts[i] != v {
return false
}
}
chvers, err := semver.NewVersion(chparts[len(chparts)-1])
if err != nil {
return false
}
return vers.Major == chvers.Major && vers.Minor >= chvers.Minor
}, nil
}
package host
import (
"testing"
)
func TestSemverMatching(t *testing.T) {
m, err := MultistreamSemverMatcher("/testing/4.3.5")
if err != nil {
t.Fatal(err)
}
cases := map[string]bool{
"/testing/4.3.0": true,
"/testing/4.3.7": true,
"/testing/4.3.5": true,
"/testing/4.2.7": true,
"/testing/4.0.0": true,
"/testing/5.0.0": false,
"/cars/dogs/4.3.5": false,
"/foo/1.0.0": false,
"": false,
"dogs": false,
"/foo": false,
"/foo/1.1.1.1": false,
}
for p, ok := range cases {
if m(p) != ok {
t.Fatalf("expected %s to be %t", p, ok)
}
}
}
...@@ -110,6 +110,10 @@ func (rh *RoutedHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandl ...@@ -110,6 +110,10 @@ func (rh *RoutedHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandl
rh.host.SetStreamHandler(pid, handler) rh.host.SetStreamHandler(pid, handler)
} }
func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) {
rh.host.SetStreamHandlerMatch(pid, m, handler)
}
func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) { func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) {
rh.host.RemoveStreamHandler(pid) rh.host.RemoveStreamHandler(pid)
} }
...@@ -125,3 +129,5 @@ func (rh *RoutedHost) Close() error { ...@@ -125,3 +129,5 @@ func (rh *RoutedHost) Close() error {
func (rh *RoutedHost) GetBandwidthReporter() metrics.Reporter { func (rh *RoutedHost) GetBandwidthReporter() metrics.Reporter {
return rh.host.GetBandwidthReporter() return rh.host.GetBandwidthReporter()
} }
var _ (host.Host) = (*RoutedHost)(nil)
...@@ -173,6 +173,7 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c inet.Conn) { ...@@ -173,6 +173,7 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c inet.Conn) {
p := c.RemotePeer() p := c.RemotePeer()
// mes.Protocols // mes.Protocols
ids.Host.Peerstore().SetProtocols(p, mes.Protocols)
// mes.ObservedAddr // mes.ObservedAddr
ids.consumeObservedAddress(mes.GetObservedAddr(), c) ids.consumeObservedAddress(mes.GetObservedAddr(), c)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment