diff --git a/p2p/net/swarm/dial_sync.go b/p2p/net/swarm/dial_sync.go new file mode 100644 index 0000000000000000000000000000000000000000..a63c047ac8552ec598262d481827fb0491862d28 --- /dev/null +++ b/p2p/net/swarm/dial_sync.go @@ -0,0 +1,92 @@ +package swarm + +import ( + "context" + "sync" + + peer "github.com/ipfs/go-libp2p-peer" +) + +type DialFunc func(context.Context, peer.ID) (*Conn, error) + +func NewDialSync(dfn DialFunc) *DialSync { + return &DialSync{ + dials: make(map[peer.ID]*activeDial), + dialFunc: dfn, + } +} + +type DialSync struct { + dials map[peer.ID]*activeDial + dialsLk sync.Mutex + dialFunc DialFunc +} + +type activeDial struct { + id peer.ID + refCnt int + refCntLk sync.Mutex + cancel func() + + err error + conn *Conn + waitch chan struct{} + + ds *DialSync +} + +func (dr *activeDial) wait(ctx context.Context) (*Conn, error) { + defer dr.decref() + select { + case <-dr.waitch: + return dr.conn, dr.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (ad *activeDial) incref() { + ad.refCntLk.Lock() + defer ad.refCntLk.Unlock() + ad.refCnt++ +} + +func (ad *activeDial) decref() { + ad.refCntLk.Lock() + defer ad.refCntLk.Unlock() + ad.refCnt-- + if ad.refCnt <= 0 { + ad.cancel() + ad.ds.dialsLk.Lock() + delete(ad.ds.dials, ad.id) + ad.ds.dialsLk.Unlock() + } +} + +func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { + ds.dialsLk.Lock() + + actd, ok := ds.dials[p] + if !ok { + ctx, cancel := context.WithCancel(context.Background()) + actd = &activeDial{ + id: p, + cancel: cancel, + waitch: make(chan struct{}), + ds: ds, + } + ds.dials[p] = actd + + go func(ctx context.Context, p peer.ID, ad *activeDial) { + ad.conn, ad.err = ds.dialFunc(ctx, p) + close(ad.waitch) + ad.cancel() + ad.waitch = nil // to ensure nobody tries reusing this + }(ctx, p, actd) + } + + actd.incref() + ds.dialsLk.Unlock() + + return actd.wait(ctx) +} diff --git a/p2p/net/swarm/dial_sync_test.go b/p2p/net/swarm/dial_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..61a69ce38692c4105066f86efbc54e73b60169f1 --- /dev/null +++ b/p2p/net/swarm/dial_sync_test.go @@ -0,0 +1,203 @@ +package swarm + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + peer "github.com/ipfs/go-libp2p-peer" +) + +func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { + dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care + dialctx, cancel := context.WithCancel(context.Background()) + ch := make(chan struct{}) + f := func(ctx context.Context, p peer.ID) (*Conn, error) { + dfcalls <- struct{}{} + defer cancel() + select { + case <-ch: + return new(Conn), nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + o := new(sync.Once) + + return f, func() { o.Do(func() { close(ch) }) }, dialctx, dfcalls +} + +func TestBasicDialSync(t *testing.T) { + df, done, _, callsch := getMockDialFunc() + + dsync := NewDialSync(df) + + p := peer.ID("testpeer") + + ctx := context.Background() + + finished := make(chan struct{}) + go func() { + _, err := dsync.DialLock(ctx, p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + go func() { + _, err := dsync.DialLock(ctx, p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + // short sleep just to make sure we've moved around in the scheduler + time.Sleep(time.Millisecond * 20) + done() + + <-finished + <-finished + + if len(callsch) > 1 { + t.Fatal("should only have called dial func once!") + } +} + +func TestDialSyncCancel(t *testing.T) { + df, done, _, dcall := getMockDialFunc() + + dsync := NewDialSync(df) + + p := peer.ID("testpeer") + + ctx1, cancel1 := context.WithCancel(context.Background()) + + finished := make(chan struct{}) + go func() { + _, err := dsync.DialLock(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + // make sure the above makes it through the wait code first + select { + case <-dcall: + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial to start") + } + + // Add a second dialwait in so two actors are waiting on the same dial + go func() { + _, err := dsync.DialLock(context.Background(), p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + time.Sleep(time.Millisecond * 20) + + // cancel the first dialwait, it should not affect the second at all + cancel1() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for wait to exit") + } + + // short sleep just to make sure we've moved around in the scheduler + time.Sleep(time.Millisecond * 20) + done() + + <-finished +} + +func TestDialSyncAllCancel(t *testing.T) { + df, done, dctx, _ := getMockDialFunc() + + dsync := NewDialSync(df) + + p := peer.ID("testpeer") + + ctx1, cancel1 := context.WithCancel(context.Background()) + + finished := make(chan struct{}) + go func() { + _, err := dsync.DialLock(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + // Add a second dialwait in so two actors are waiting on the same dial + go func() { + _, err := dsync.DialLock(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + cancel1() + for i := 0; i < 2; i++ { + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for wait to exit") + } + } + + // the dial should have exited now + select { + case <-dctx.Done(): + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial to return") + } + + // should be able to successfully dial that peer again + done() + _, err := dsync.DialLock(context.Background(), p) + if err != nil { + t.Fatal(err) + } +} + +func TestFailFirst(t *testing.T) { + var count int + f := func(ctx context.Context, p peer.ID) (*Conn, error) { + if count > 0 { + return new(Conn), nil + } + count++ + return nil, fmt.Errorf("gophers ate the modem") + } + + ds := NewDialSync(f) + + p := peer.ID("testing") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, err := ds.DialLock(ctx, p) + if err == nil { + t.Fatal("expected gophers to have eaten the modem") + } + + c, err := ds.DialLock(ctx, p) + if err != nil { + t.Fatal(err) + } + + if c == nil { + t.Fatal("should have gotten a 'real' conn back") + } +} diff --git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go index cfdb996c286ccb672321fd20f8ca8c05a5bd4dd4..aeeb6ae42d48344b66034a1d6da828008c90d016 100644 --- a/p2p/net/swarm/dial_test.go +++ b/p2p/net/swarm/dial_test.go @@ -415,7 +415,6 @@ func TestDialBackoff(t *testing.T) { if !s1.backf.Backoff(s3p) { t.Error("s3 should be on backoff") } - } } diff --git a/p2p/net/swarm/limiter.go b/p2p/net/swarm/limiter.go index 4bbfb11862e8efb9fe476711a118896e48a9d2e1..c34954a5276c3bfa6d5cda40ff156ba917ffe5aa 100644 --- a/p2p/net/swarm/limiter.go +++ b/p2p/net/swarm/limiter.go @@ -1,11 +1,11 @@ package swarm import ( + "context" "sync" peer "github.com/ipfs/go-libp2p-peer" ma "github.com/jbenet/go-multiaddr" - context "golang.org/x/net/context" conn "github.com/libp2p/go-libp2p/p2p/net/conn" addrutil "github.com/libp2p/go-libp2p/p2p/net/swarm/addr" diff --git a/p2p/net/swarm/limiter_test.go b/p2p/net/swarm/limiter_test.go index f5fc18746f702fdb8ab6538cb673275f25498d43..93761a5048fa76ef9556787147720f4803e302a4 100644 --- a/p2p/net/swarm/limiter_test.go +++ b/p2p/net/swarm/limiter_test.go @@ -1,6 +1,7 @@ package swarm import ( + "context" "fmt" "math/rand" "strconv" @@ -10,7 +11,6 @@ import ( peer "github.com/ipfs/go-libp2p-peer" ma "github.com/jbenet/go-multiaddr" mafmt "github.com/whyrusleeping/mafmt" - context "golang.org/x/net/context" conn "github.com/libp2p/go-libp2p/p2p/net/conn" ) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index abab8e8a4303cba9a7f57d291199f5447df1ad37..e36f763fb9ad66783bef2a0792312b61ec5524bf 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -3,6 +3,7 @@ package swarm import ( + "context" "fmt" "io/ioutil" "os" @@ -32,7 +33,6 @@ import ( yamux "github.com/whyrusleeping/go-smux-yamux" mafilter "github.com/whyrusleeping/multiaddr-filter" ws "github.com/whyrusleeping/ws-transport" - context "golang.org/x/net/context" ) var log = logging.Logger("swarm2") @@ -76,7 +76,7 @@ type Swarm struct { peers pstore.Peerstore connh ConnHandler - dsync dialsync + dsync *DialSync backf dialbackoff dialT time.Duration // mainly for tests @@ -134,6 +134,7 @@ func NewSwarm(ctx context.Context, listenAddrs []ma.Multiaddr, dialer: conn.NewDialer(local, peers.PrivKey(local), wrap), } + s.dsync = NewDialSync(s.doDial) s.limiter = newDialLimiter(s.dialAddr) // configure Swarm diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 6f82204347928ab166d2045904257fa2d1eff450..573651778bf744a4876d24f75316ebdb010cb5a5 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -1,6 +1,7 @@ package swarm import ( + "context" "errors" "fmt" "sync" @@ -11,7 +12,6 @@ import ( ma "github.com/jbenet/go-multiaddr" conn "github.com/libp2p/go-libp2p/p2p/net/conn" addrutil "github.com/libp2p/go-libp2p/p2p/net/swarm/addr" - context "golang.org/x/net/context" ) // Diagram of dial sync: @@ -53,78 +53,6 @@ const defaultPerPeerRateLimit = 8 // subcomponent of Dial) var DialTimeout = time.Second * 10 -// dialsync is a small object that helps manage ongoing dials. -// this way, if we receive many simultaneous dial requests, one -// can do its thing, while the rest wait. -// -// this interface is so would-be dialers can just: -// -// for { -// c := findConnectionToPeer(peer) -// if c != nil { -// return c -// } -// -// // ok, no connections. should we dial? -// if ok, wait := dialsync.Lock(peer); !ok { -// <-wait // can optionally wait -// continue -// } -// defer dialsync.Unlock(peer) -// -// c := actuallyDial(peer) -// return c -// } -// -type dialsync struct { - // ongoing is a map of tickets for the current peers being dialed. - // this way, we dont kick off N dials simultaneously. - ongoing map[peer.ID]chan struct{} - lock sync.Mutex -} - -// Lock governs the beginning of a dial attempt. -// If there are no ongoing dials, it returns true, and the client is now -// scheduled to dial. Every other goroutine that calls startDial -- with -//the same dst -- will block until client is done. The client MUST call -// ds.Unlock(p) when it is done, to unblock the other callers. -// The client is not reponsible for achieving a successful dial, only for -// reporting the end of the attempt (calling ds.Unlock(p)). -// -// see the example below `dialsync` -func (ds *dialsync) Lock(dst peer.ID) (bool, chan struct{}) { - ds.lock.Lock() - if ds.ongoing == nil { // init if not ready - ds.ongoing = make(map[peer.ID]chan struct{}) - } - wait, found := ds.ongoing[dst] - if !found { - ds.ongoing[dst] = make(chan struct{}) - } - ds.lock.Unlock() - - if found { - return false, wait - } - - // ok! you're signed up to dial! - return true, nil -} - -// Unlock releases waiters to a dial attempt. see Lock. -// if Unlock(p) is called without calling Lock(p) first, Unlock panics. -func (ds *dialsync) Unlock(dst peer.ID) { - ds.lock.Lock() - wait, found := ds.ongoing[dst] - if !found { - panic("called dialDone with no ongoing dials to peer: " + dst.Pretty()) - } - - delete(ds.ongoing, dst) // remove ongoing dial - close(wait) // release everyone else - ds.lock.Unlock() -} - // dialbackoff is a struct used to avoid over-dialing the same, dead peers. // Whenever we totally time out on a peer (all three attempts), we add them // to dialbackoff. Then, whenevers goroutines would _wait_ (dialsync), they @@ -246,8 +174,7 @@ func (s *Swarm) bestConnectionToPeer(p peer.ID) *Conn { // gatedDialAttempt is an attempt to dial a node. It is gated by the swarm's // dial synchronization systems: dialsync and dialbackoff. func (s *Swarm) gatedDialAttempt(ctx context.Context, p peer.ID) (*Conn, error) { - var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) - defer log.EventBegin(ctx, "swarmDialAttemptSync", logdial).Done() + defer log.EventBegin(ctx, "swarmDialAttemptSync", p).Done() // check if we already have an open connection first conn := s.bestConnectionToPeer(p) @@ -255,57 +182,36 @@ func (s *Swarm) gatedDialAttempt(ctx context.Context, p peer.ID) (*Conn, error) return conn, nil } - // check if there's an ongoing dial to this peer - if ok, wait := s.dsync.Lock(p); ok { - defer s.dsync.Unlock(p) - - // if this peer has been backed off, lets get out of here - if s.backf.Backoff(p) { - log.Event(ctx, "swarmDialBackoff", logdial) - return nil, ErrDialBackoff - } - - // ok, we have been charged to dial! let's do it. - // if it succeeds, dial will add the conn to the swarm itself. - defer log.EventBegin(ctx, "swarmDialAttemptStart", logdial).Done() - ctxT, cancel := context.WithTimeout(ctx, s.dialT) - conn, err := s.dial(ctxT, p) - cancel() - log.Debugf("dial end %s", conn) - if err != nil { - log.Event(ctx, "swarmDialBackoffAdd", logdial) - s.backf.AddBackoff(p) // let others know to backoff - - // ok, we failed. try again. (if loop is done, our error is output) - return nil, fmt.Errorf("dial attempt failed: %s", err) - } - log.Event(ctx, "swarmDialBackoffClear", logdial) - s.backf.Clear(p) // okay, no longer need to backoff - return conn, nil - - } else { - // we did not dial. we must wait for someone else to dial. + // if this peer has been backed off, lets get out of here + if s.backf.Backoff(p) { + log.Event(ctx, "swarmDialBackoff", p) + return nil, ErrDialBackoff + } - // check whether we should backoff first... - if s.backf.Backoff(p) { - log.Event(ctx, "swarmDialBackoff", logdial) - return nil, ErrDialBackoff - } + return s.dsync.DialLock(ctx, p) +} - defer log.EventBegin(ctx, "swarmDialWait", logdial).Done() - select { - case <-wait: // wait for that other dial to finish. +// doDial is an ugly shim method to retain all the logging and backoff logic +// of the old dialsync code +func (s *Swarm) doDial(ctx context.Context, p peer.ID) (*Conn, error) { + var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) + // ok, we have been charged to dial! let's do it. + // if it succeeds, dial will add the conn to the swarm itself. + defer log.EventBegin(ctx, "swarmDialAttemptStart", logdial).Done() + ctxT, cancel := context.WithTimeout(ctx, s.dialT) + conn, err := s.dial(ctxT, p) + cancel() + log.Debugf("dial end %s", conn) + if err != nil { + log.Event(ctx, "swarmDialBackoffAdd", logdial) + s.backf.AddBackoff(p) // let others know to backoff - // see if it worked, OR we got an incoming dial in the meantime... - conn := s.bestConnectionToPeer(p) - if conn != nil { - return conn, nil - } - return nil, ErrDialFailed - case <-ctx.Done(): // or we may have to bail... - return nil, ctx.Err() - } + // ok, we failed. try again. (if loop is done, our error is output) + return nil, fmt.Errorf("dial attempt failed: %s", err) } + log.Event(ctx, "swarmDialBackoffClear", logdial) + s.backf.Clear(p) // okay, no longer need to backoff + return conn, nil } // dial is the actual swarm's dial logic, gated by Dial.