Commit 952df060 authored by Jeromy Johnson's avatar Jeromy Johnson Committed by GitHub
Browse files

Merge pull request #231 from libp2p/fix/drain-on-close

make sure to not drop writes on close
parents 461faf4a 860d2784
......@@ -17,9 +17,12 @@ type stream struct {
read *io.PipeReader
conn *conn
toDeliver chan *transportObject
control chan int
state int
closed chan struct{}
reset chan struct{}
close chan struct{}
closed chan struct{}
state error
protocol protocol.ID
}
......@@ -27,12 +30,6 @@ type stream struct {
var ErrReset error = errors.New("stream reset")
var ErrClosed error = errors.New("stream closed")
const (
stateOpen = iota
stateClose
stateReset
)
type transportObject struct {
msg []byte
arrivalTime time.Time
......@@ -42,7 +39,8 @@ func NewStream(w *io.PipeWriter, r *io.PipeReader) *stream {
s := &stream{
read: r,
write: w,
control: make(chan int),
reset: make(chan struct{}, 1),
close: make(chan struct{}, 1),
closed: make(chan struct{}),
toDeliver: make(chan *transportObject),
}
......@@ -58,12 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) {
t := time.Now().Add(delay)
select {
case <-s.closed: // bail out if we're closing.
switch s.state {
case stateReset:
return 0, ErrReset
case stateClose:
return 0, ErrClosed
}
return 0, s.state
case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}:
}
return len(p), nil
......@@ -79,31 +72,29 @@ func (s *stream) SetProtocol(proto protocol.ID) {
func (s *stream) Close() error {
select {
case s.control <- stateClose:
case <-s.closed:
case s.close <- struct{}{}:
default:
}
<-s.closed
if s.state == stateReset {
return nil
} else {
return ErrClosed
if s.state != ErrClosed {
return s.state
}
return nil
}
func (s *stream) Reset() error {
// Cancel any pending reads.
// Cancel any pending writes.
s.write.Close()
select {
case s.control <- stateReset:
case <-s.closed:
case s.reset <- struct{}{}:
default:
}
<-s.closed
if s.state == stateReset {
return nil
} else {
return ErrClosed
if s.state != ErrReset {
return s.state
}
return nil
}
func (s *stream) teardown() {
......@@ -128,11 +119,11 @@ func (s *stream) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (p *stream) SetReadDeadline(t time.Time) error {
func (s *stream) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (p *stream) SetWriteDeadline(t time.Time) error {
func (s *stream) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
......@@ -197,7 +188,8 @@ func (s *stream) transport() {
if buffered >= bufsize {
select {
case <-timer.C:
case s.state = <-s.control:
case <-s.reset:
s.reset <- struct{}{}
return
}
drainBuf()
......@@ -212,25 +204,25 @@ func (s *stream) transport() {
}
for {
switch s.state {
case stateClose:
drainBuf()
return
case stateReset:
// Reset takes precedent.
select {
case <-s.reset:
s.state = ErrReset
s.read.CloseWithError(ErrReset)
return
default:
panic("invalid state")
case stateOpen:
}
select {
case s.state = <-s.control:
continue
case o, ok := <-s.toDeliver:
if !ok {
return
}
case <-s.reset:
s.state = ErrReset
s.read.CloseWithError(ErrReset)
return
case <-s.close:
s.state = ErrClosed
drainBuf()
return
case o := <-s.toDeliver:
deliverOrWait(o)
case <-timer.C: // ok, due to write it out.
drainBuf()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment