Commit 7a3394b1 authored by Jeromy's avatar Jeromy Committed by Jeromy Johnson
Browse files

host: remember which protocols work for a given peer

parent d8468400
...@@ -97,7 +97,7 @@ func main() { ...@@ -97,7 +97,7 @@ func main() {
log.Println("opening stream...") log.Println("opening stream...")
// make a new stream from host B to host A // make a new stream from host B to host A
// it should be handled on host A by the handler we set // it should be handled on host A by the handler we set
s, err := ha.NewStream(context.Background(), "/hello/1.0.0", a.ID()) s, err := ha.NewStream(context.Background(), a.ID(), "/hello/1.0.0")
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
} }
......
...@@ -2,6 +2,7 @@ package basichost ...@@ -2,6 +2,7 @@ package basichost
import ( import (
"io" "io"
"sync"
peer "github.com/ipfs/go-libp2p-peer" peer "github.com/ipfs/go-libp2p-peer"
pstore "github.com/ipfs/go-libp2p-peerstore" pstore "github.com/ipfs/go-libp2p-peerstore"
...@@ -46,6 +47,9 @@ type BasicHost struct { ...@@ -46,6 +47,9 @@ type BasicHost struct {
relay *relay.RelayService relay *relay.RelayService
natmgr *natManager natmgr *natManager
protoPrefs map[peer.ID]map[protocol.ID]struct{}
prefsLk sync.Mutex
proc goprocess.Process proc goprocess.Process
bwc metrics.Reporter bwc metrics.Reporter
...@@ -57,6 +61,7 @@ func New(net inet.Network, opts ...interface{}) *BasicHost { ...@@ -57,6 +61,7 @@ func New(net inet.Network, opts ...interface{}) *BasicHost {
network: net, network: net,
mux: msmux.NewMultistreamMuxer(), mux: msmux.NewMultistreamMuxer(),
bwc: metrics.NewBandwidthCounter(), bwc: metrics.NewBandwidthCounter(),
protoPrefs: make(map[peer.ID]map[protocol.ID]struct{}),
} }
h.proc = goprocess.WithTeardown(func() error { h.proc = goprocess.WithTeardown(func() error {
...@@ -176,7 +181,58 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { ...@@ -176,7 +181,58 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
// header with given protocol.ID. If there is no connection to p, attempts // header with given protocol.ID. If there is no connection to p, attempts
// to create one. If ProtocolID is "", writes no header. // to create one. If ProtocolID is "", writes no header.
// (Threadsafe) // (Threadsafe)
func (h *BasicHost) NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) { func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) {
pref := h.preferredProtocol(p, pids)
if pref != "" {
return h.newStream(ctx, p, pref)
}
var lastErr error
for _, pid := range pids {
s, err := h.newStream(ctx, p, pid)
if err == nil {
h.setPreferredProtocol(p, pid)
return s, nil
}
lastErr = err
log.Infof("NewStream to %s for %s failed: %s", p, pid, err)
}
return nil, lastErr
}
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID {
h.prefsLk.Lock()
defer h.prefsLk.Unlock()
prefs, ok := h.protoPrefs[p]
if !ok {
return ""
}
for _, pid := range pids {
if _, ok := prefs[pid]; ok {
return pid
}
}
return ""
}
func (h *BasicHost) setPreferredProtocol(p peer.ID, proto protocol.ID) {
h.prefsLk.Lock()
defer h.prefsLk.Unlock()
prefs, ok := h.protoPrefs[p]
if !ok {
prefs = make(map[protocol.ID]struct{})
h.protoPrefs[p] = prefs
}
prefs[proto] = struct{}{}
}
func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) (inet.Stream, error) {
s, err := h.Network().NewStream(ctx, p) s, err := h.Network().NewStream(ctx, p)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -32,7 +32,7 @@ func TestHostSimple(t *testing.T) { ...@@ -32,7 +32,7 @@ func TestHostSimple(t *testing.T) {
io.Copy(w, s) // mirror everything io.Copy(w, s) // mirror everything
}) })
s, err := h1.NewStream(ctx, protocol.TestingID, h2pi.ID) s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -60,7 +60,7 @@ type Host interface { ...@@ -60,7 +60,7 @@ type Host interface {
// header with given protocol.ID. If there is no connection to p, attempts // header with given protocol.ID. If there is no connection to p, attempts
// to create one. If ProtocolID is "", writes no header. // to create one. If ProtocolID is "", writes no header.
// (Threadsafe) // (Threadsafe)
NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error)
// Close shuts down the host, its Network, and services. // Close shuts down the host, its Network, and services.
Close() error Close() error
......
...@@ -118,8 +118,8 @@ func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) { ...@@ -118,8 +118,8 @@ func (rh *RoutedHost) RemoveStreamHandler(pid protocol.ID) {
rh.host.RemoveStreamHandler(pid) rh.host.RemoveStreamHandler(pid)
} }
func (rh *RoutedHost) NewStream(ctx context.Context, pid protocol.ID, p peer.ID) (inet.Stream, error) { func (rh *RoutedHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (inet.Stream, error) {
return rh.host.NewStream(ctx, pid, p) return rh.host.NewStream(ctx, p, pids...)
} }
func (rh *RoutedHost) Close() error { func (rh *RoutedHost) Close() error {
// no need to close IpfsRouting. we dont own it. // no need to close IpfsRouting. we dont own it.
......
...@@ -298,7 +298,7 @@ func TestStreams(t *testing.T) { ...@@ -298,7 +298,7 @@ func TestStreams(t *testing.T) {
h.SetStreamHandler(protocol.TestingID, handler) h.SetStreamHandler(protocol.TestingID, handler)
} }
s, err := hosts[0].NewStream(ctx, protocol.TestingID, hosts[1].ID()) s, err := hosts[0].NewStream(ctx, hosts[1].ID(), protocol.TestingID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -386,7 +386,7 @@ func TestStreamsStress(t *testing.T) { ...@@ -386,7 +386,7 @@ func TestStreamsStress(t *testing.T) {
defer wg.Done() defer wg.Done()
from := rand.Intn(len(hosts)) from := rand.Intn(len(hosts))
to := rand.Intn(len(hosts)) to := rand.Intn(len(hosts))
s, err := hosts[from].NewStream(ctx, protocol.TestingID, hosts[to].ID()) s, err := hosts[from].NewStream(ctx, hosts[to].ID(), protocol.TestingID)
if err != nil { if err != nil {
log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to]) log.Debugf("%d (%s) %d (%s)", from, hosts[from], to, hosts[to])
panic(err) panic(err)
...@@ -466,7 +466,7 @@ func TestAdding(t *testing.T) { ...@@ -466,7 +466,7 @@ func TestAdding(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
s, err := h1.NewStream(ctx, protocol.TestingID, p2) s, err := h1.NewStream(ctx, p2, protocol.TestingID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -563,7 +563,7 @@ func TestLimitedStreams(t *testing.T) { ...@@ -563,7 +563,7 @@ func TestLimitedStreams(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
s, err := hosts[0].NewStream(ctx, protocol.TestingID, hosts[1].ID()) s, err := hosts[0].NewStream(ctx, hosts[1].ID(), protocol.TestingID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -68,7 +68,7 @@ func (p *PingService) PingHandler(s inet.Stream) { ...@@ -68,7 +68,7 @@ func (p *PingService) PingHandler(s inet.Stream) {
} }
func (ps *PingService) Ping(ctx context.Context, p peer.ID) (<-chan time.Duration, error) { func (ps *PingService) Ping(ctx context.Context, p peer.ID) (<-chan time.Duration, error) {
s, err := ps.Host.NewStream(ctx, ID, p) s, err := ps.Host.NewStream(ctx, p, ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -123,7 +123,7 @@ func (rs *RelayService) pipeStream(src, dst peer.ID, s inet.Stream) error { ...@@ -123,7 +123,7 @@ func (rs *RelayService) pipeStream(src, dst peer.ID, s inet.Stream) error {
// for now, can only open streams to directly connected peers. // for now, can only open streams to directly connected peers.
// maybe we can do some routing later on. // maybe we can do some routing later on.
func (rs *RelayService) openStreamToPeer(ctx context.Context, p peer.ID) (inet.Stream, error) { func (rs *RelayService) openStreamToPeer(ctx context.Context, p peer.ID) (inet.Stream, error) {
return rs.host.NewStream(ctx, ID, p) return rs.host.NewStream(ctx, p, ID)
} }
func ReadHeader(r io.Reader) (src, dst peer.ID, err error) { func ReadHeader(r io.Reader) (src, dst peer.ID, err error) {
......
...@@ -49,7 +49,7 @@ func TestRelaySimple(t *testing.T) { ...@@ -49,7 +49,7 @@ func TestRelaySimple(t *testing.T) {
// ok, now we can try to relay n1--->n2--->n3. // ok, now we can try to relay n1--->n2--->n3.
log.Debug("open relay stream") log.Debug("open relay stream")
s, err := n1.NewStream(ctx, relay.ID, n2p) s, err := n1.NewStream(ctx, n2p, relay.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -144,7 +144,7 @@ func TestRelayAcrossFour(t *testing.T) { ...@@ -144,7 +144,7 @@ func TestRelayAcrossFour(t *testing.T) {
// ok, now we can try to relay n1--->n2--->n3--->n4--->n5 // ok, now we can try to relay n1--->n2--->n3--->n4--->n5
log.Debug("open relay stream") log.Debug("open relay stream")
s, err := n1.NewStream(ctx, relay.ID, n2p) s, err := n1.NewStream(ctx, n2p, relay.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -244,7 +244,7 @@ func TestRelayStress(t *testing.T) { ...@@ -244,7 +244,7 @@ func TestRelayStress(t *testing.T) {
// ok, now we can try to relay n1--->n2--->n3. // ok, now we can try to relay n1--->n2--->n3.
log.Debug("open relay stream") log.Debug("open relay stream")
s, err := n1.NewStream(ctx, relay.ID, n2p) s, err := n1.NewStream(ctx, n2p, relay.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -83,7 +83,7 @@ a problem. ...@@ -83,7 +83,7 @@ a problem.
}() }()
for { for {
s, err = host.NewStream(context.Background(), protocol.TestingID, remote) s, err = host.NewStream(context.Background(), remote, protocol.TestingID)
if err != nil { if err != nil {
return return
} }
...@@ -285,7 +285,7 @@ func TestStBackpressureStreamWrite(t *testing.T) { ...@@ -285,7 +285,7 @@ func TestStBackpressureStreamWrite(t *testing.T) {
} }
// open a stream, from 2->1, this is our reader // open a stream, from 2->1, this is our reader
s, err := h2.NewStream(context.Background(), protocol.TestingID, h1.ID()) s, err := h2.NewStream(context.Background(), h1.ID(), protocol.TestingID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -177,7 +177,7 @@ func SubtestConnSendDisc(t *testing.T, hosts []host.Host) { ...@@ -177,7 +177,7 @@ func SubtestConnSendDisc(t *testing.T, hosts []host.Host) {
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
h1 := hosts[i%len(hosts)] h1 := hosts[i%len(hosts)]
h2 := hosts[(i+1)%len(hosts)] h2 := hosts[(i+1)%len(hosts)]
s, err := h1.NewStream(context.Background(), protocol.TestingID, h2.ID()) s, err := h1.NewStream(context.Background(), h2.ID(), protocol.TestingID)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment