Unverified Commit 5ba748bd authored by Marten Seemann's avatar Marten Seemann
Browse files

implement IsClosed() for the connection

parent 7d27130e
...@@ -3,6 +3,7 @@ package libp2pquic ...@@ -3,6 +3,7 @@ package libp2pquic
import ( import (
"fmt" "fmt"
"net" "net"
"sync"
smux "github.com/jbenet/go-stream-muxer" smux "github.com/jbenet/go-stream-muxer"
tpt "github.com/libp2p/go-libp2p-transport" tpt "github.com/libp2p/go-libp2p-transport"
...@@ -12,11 +13,15 @@ import ( ...@@ -12,11 +13,15 @@ import (
) )
type quicConn struct { type quicConn struct {
mutex sync.RWMutex
sess quic.Session sess quic.Session
transport tpt.Transport transport tpt.Transport
laddr ma.Multiaddr laddr ma.Multiaddr
raddr ma.Multiaddr raddr ma.Multiaddr
closed bool
} }
var _ tpt.Conn = &quicConn{} var _ tpt.Conn = &quicConn{}
...@@ -24,7 +29,6 @@ var _ tpt.MultiStreamConn = &quicConn{} ...@@ -24,7 +29,6 @@ var _ tpt.MultiStreamConn = &quicConn{}
func newQuicConn(sess quic.Session, t tpt.Transport) (*quicConn, error) { func newQuicConn(sess quic.Session, t tpt.Transport) (*quicConn, error) {
// analogues to manet.WrapNetConn // analogues to manet.WrapNetConn
laddr, err := quicMultiAddress(sess.LocalAddr()) laddr, err := quicMultiAddress(sess.LocalAddr())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err) return nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err)
...@@ -36,12 +40,16 @@ func newQuicConn(sess quic.Session, t tpt.Transport) (*quicConn, error) { ...@@ -36,12 +40,16 @@ func newQuicConn(sess quic.Session, t tpt.Transport) (*quicConn, error) {
return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
} }
return &quicConn{ c := &quicConn{
sess: sess, sess: sess,
laddr: laddr, laddr: laddr,
raddr: raddr, raddr: raddr,
transport: t, transport: t,
}, nil }
go c.watchClosed()
return c, nil
} }
func (c *quicConn) AcceptStream() (smux.Stream, error) { func (c *quicConn) AcceptStream() (smux.Stream, error) {
...@@ -76,9 +84,17 @@ func (c *quicConn) Close() error { ...@@ -76,9 +84,17 @@ func (c *quicConn) Close() error {
return c.sess.Close(nil) return c.sess.Close(nil)
} }
// TODO: implement this func (c *quicConn) watchClosed() {
c.sess.WaitUntilClosed()
c.mutex.Lock()
c.closed = true
c.mutex.Unlock()
}
func (c *quicConn) IsClosed() bool { func (c *quicConn) IsClosed() bool {
return false c.mutex.Lock()
defer c.mutex.Unlock()
return c.closed
} }
func (c *quicConn) LocalAddr() net.Addr { func (c *quicConn) LocalAddr() net.Addr {
......
...@@ -25,7 +25,8 @@ func (s *mockStream) StreamID() protocol.StreamID { return s.id } ...@@ -25,7 +25,8 @@ func (s *mockStream) StreamID() protocol.StreamID { return s.id }
var _ quic.Stream = &mockStream{} var _ quic.Stream = &mockStream{}
type mockQuicSession struct { type mockQuicSession struct {
closed bool closed bool
waitUntilClosedChan chan struct{} // close this chan to make WaitUntilClosed return
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
...@@ -49,6 +50,7 @@ func (s *mockQuicSession) OpenStreamSync() (quic.Stream, error) { ...@@ -49,6 +50,7 @@ func (s *mockQuicSession) OpenStreamSync() (quic.Stream, error) {
func (s *mockQuicSession) Close(error) error { s.closed = true; return nil } func (s *mockQuicSession) Close(error) error { s.closed = true; return nil }
func (s *mockQuicSession) LocalAddr() net.Addr { return s.localAddr } func (s *mockQuicSession) LocalAddr() net.Addr { return s.localAddr }
func (s *mockQuicSession) RemoteAddr() net.Addr { return s.remoteAddr } func (s *mockQuicSession) RemoteAddr() net.Addr { return s.remoteAddr }
func (s *mockQuicSession) WaitUntilClosed() { <-s.waitUntilClosedChan }
var _ = Describe("Conn", func() { var _ = Describe("Conn", func() {
var ( var (
...@@ -58,8 +60,9 @@ var _ = Describe("Conn", func() { ...@@ -58,8 +60,9 @@ var _ = Describe("Conn", func() {
BeforeEach(func() { BeforeEach(func() {
sess = &mockQuicSession{ sess = &mockQuicSession{
localAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, localAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
remoteAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1234}, remoteAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1234},
waitUntilClosedChan: make(chan struct{}),
} }
var err error var err error
conn, err = newQuicConn(sess, nil) conn, err = newQuicConn(sess, nil)
...@@ -82,6 +85,12 @@ var _ = Describe("Conn", func() { ...@@ -82,6 +85,12 @@ var _ = Describe("Conn", func() {
Expect(sess.closed).To(BeTrue()) Expect(sess.closed).To(BeTrue())
}) })
It("says if it is closed", func() {
Consistently(func() bool { return conn.IsClosed() }).Should(BeFalse())
close(sess.waitUntilClosedChan)
Eventually(func() bool { return conn.IsClosed() }).Should(BeTrue())
})
Context("opening streams", func() { Context("opening streams", func() {
It("opens streams", func() { It("opens streams", func() {
s := &mockStream{id: 1337} s := &mockStream{id: 1337}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment