diff --git a/p2p/net/swarm/dial_sync.go b/p2p/net/swarm/dial_sync.go new file mode 100644 index 0000000000000000000000000000000000000000..a63c047ac8552ec598262d481827fb0491862d28 --- /dev/null +++ b/p2p/net/swarm/dial_sync.go @@ -0,0 +1,92 @@ +package swarm + +import ( + "context" + "sync" + + peer "github.com/ipfs/go-libp2p-peer" +) + +type DialFunc func(context.Context, peer.ID) (*Conn, error) + +func NewDialSync(dfn DialFunc) *DialSync { + return &DialSync{ + dials: make(map[peer.ID]*activeDial), + dialFunc: dfn, + } +} + +type DialSync struct { + dials map[peer.ID]*activeDial + dialsLk sync.Mutex + dialFunc DialFunc +} + +type activeDial struct { + id peer.ID + refCnt int + refCntLk sync.Mutex + cancel func() + + err error + conn *Conn + waitch chan struct{} + + ds *DialSync +} + +func (dr *activeDial) wait(ctx context.Context) (*Conn, error) { + defer dr.decref() + select { + case <-dr.waitch: + return dr.conn, dr.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (ad *activeDial) incref() { + ad.refCntLk.Lock() + defer ad.refCntLk.Unlock() + ad.refCnt++ +} + +func (ad *activeDial) decref() { + ad.refCntLk.Lock() + defer ad.refCntLk.Unlock() + ad.refCnt-- + if ad.refCnt <= 0 { + ad.cancel() + ad.ds.dialsLk.Lock() + delete(ad.ds.dials, ad.id) + ad.ds.dialsLk.Unlock() + } +} + +func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { + ds.dialsLk.Lock() + + actd, ok := ds.dials[p] + if !ok { + ctx, cancel := context.WithCancel(context.Background()) + actd = &activeDial{ + id: p, + cancel: cancel, + waitch: make(chan struct{}), + ds: ds, + } + ds.dials[p] = actd + + go func(ctx context.Context, p peer.ID, ad *activeDial) { + ad.conn, ad.err = ds.dialFunc(ctx, p) + close(ad.waitch) + ad.cancel() + ad.waitch = nil // to ensure nobody tries reusing this + }(ctx, p, actd) + } + + actd.incref() + ds.dialsLk.Unlock() + + return actd.wait(ctx) +} diff --git a/p2p/net/swarm/dial_sync_test.go b/p2p/net/swarm/dial_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..61a69ce38692c4105066f86efbc54e73b60169f1 --- /dev/null +++ b/p2p/net/swarm/dial_sync_test.go @@ -0,0 +1,203 @@ +package swarm + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + peer "github.com/ipfs/go-libp2p-peer" +) + +func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { + dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care + dialctx, cancel := context.WithCancel(context.Background()) + ch := make(chan struct{}) + f := func(ctx context.Context, p peer.ID) (*Conn, error) { + dfcalls <- struct{}{} + defer cancel() + select { + case <-ch: + return new(Conn), nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + o := new(sync.Once) + + return f, func() { o.Do(func() { close(ch) }) }, dialctx, dfcalls +} + +func TestBasicDialSync(t *testing.T) { + df, done, _, callsch := getMockDialFunc() + + dsync := NewDialSync(df) + + p := peer.ID("testpeer") + + ctx := context.Background() + + finished := make(chan struct{}) + go func() { + _, err := dsync.DialLock(ctx, p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + go func() { + _, err := dsync.DialLock(ctx, p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + // short sleep just to make sure we've moved around in the scheduler + time.Sleep(time.Millisecond * 20) + done() + + <-finished + <-finished + + if len(callsch) > 1 { + t.Fatal("should only have called dial func once!") + } +} + +func TestDialSyncCancel(t *testing.T) { + df, done, _, dcall := getMockDialFunc() + + dsync := NewDialSync(df) + + p := peer.ID("testpeer") + + ctx1, cancel1 := context.WithCancel(context.Background()) + + finished := make(chan struct{}) + go func() { + _, err := dsync.DialLock(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + // make sure the above makes it through the wait code first + select { + case <-dcall: + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial to start") + } + + // Add a second dialwait in so two actors are waiting on the same dial + go func() { + _, err := dsync.DialLock(context.Background(), p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + time.Sleep(time.Millisecond * 20) + + // cancel the first dialwait, it should not affect the second at all + cancel1() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for wait to exit") + } + + // short sleep just to make sure we've moved around in the scheduler + time.Sleep(time.Millisecond * 20) + done() + + <-finished +} + +func TestDialSyncAllCancel(t *testing.T) { + df, done, dctx, _ := getMockDialFunc() + + dsync := NewDialSync(df) + + p := peer.ID("testpeer") + + ctx1, cancel1 := context.WithCancel(context.Background()) + + finished := make(chan struct{}) + go func() { + _, err := dsync.DialLock(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + // Add a second dialwait in so two actors are waiting on the same dial + go func() { + _, err := dsync.DialLock(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + cancel1() + for i := 0; i < 2; i++ { + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for wait to exit") + } + } + + // the dial should have exited now + select { + case <-dctx.Done(): + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial to return") + } + + // should be able to successfully dial that peer again + done() + _, err := dsync.DialLock(context.Background(), p) + if err != nil { + t.Fatal(err) + } +} + +func TestFailFirst(t *testing.T) { + var count int + f := func(ctx context.Context, p peer.ID) (*Conn, error) { + if count > 0 { + return new(Conn), nil + } + count++ + return nil, fmt.Errorf("gophers ate the modem") + } + + ds := NewDialSync(f) + + p := peer.ID("testing") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, err := ds.DialLock(ctx, p) + if err == nil { + t.Fatal("expected gophers to have eaten the modem") + } + + c, err := ds.DialLock(ctx, p) + if err != nil { + t.Fatal(err) + } + + if c == nil { + t.Fatal("should have gotten a 'real' conn back") + } +}