diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 2f0a4ee7122da27b0dcffaeffb762a1b1ae9eaf5..79a3834c0aeb5660ec7fc17314d21c17c3df7b68 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -17,9 +17,12 @@ type stream struct { read *io.PipeReader conn *conn toDeliver chan *transportObject - control chan int - state int - closed chan struct{} + + reset chan struct{} + close chan struct{} + closed chan struct{} + + state error protocol protocol.ID } @@ -27,12 +30,6 @@ type stream struct { var ErrReset error = errors.New("stream reset") var ErrClosed error = errors.New("stream closed") -const ( - stateOpen = iota - stateClose - stateReset -) - type transportObject struct { msg []byte arrivalTime time.Time @@ -42,7 +39,8 @@ func NewStream(w *io.PipeWriter, r *io.PipeReader) *stream { s := &stream{ read: r, write: w, - control: make(chan int), + reset: make(chan struct{}, 1), + close: make(chan struct{}, 1), closed: make(chan struct{}), toDeliver: make(chan *transportObject), } @@ -58,12 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) { t := time.Now().Add(delay) select { case <-s.closed: // bail out if we're closing. - switch s.state { - case stateReset: - return 0, ErrReset - case stateClose: - return 0, ErrClosed - } + return 0, s.state case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}: } return len(p), nil @@ -79,31 +72,29 @@ func (s *stream) SetProtocol(proto protocol.ID) { func (s *stream) Close() error { select { - case s.control <- stateClose: - case <-s.closed: + case s.close <- struct{}{}: + default: } <-s.closed - if s.state == stateReset { - return nil - } else { - return ErrClosed + if s.state != ErrClosed { + return s.state } + return nil } func (s *stream) Reset() error { - // Cancel any pending reads. + // Cancel any pending writes. s.write.Close() select { - case s.control <- stateReset: - case <-s.closed: + case s.reset <- struct{}{}: + default: } <-s.closed - if s.state == stateReset { - return nil - } else { - return ErrClosed + if s.state != ErrReset { + return s.state } + return nil } func (s *stream) teardown() { @@ -128,11 +119,11 @@ func (s *stream) SetDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (p *stream) SetReadDeadline(t time.Time) error { +func (s *stream) SetReadDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (p *stream) SetWriteDeadline(t time.Time) error { +func (s *stream) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } @@ -197,7 +188,8 @@ func (s *stream) transport() { if buffered >= bufsize { select { case <-timer.C: - case s.state = <-s.control: + case <-s.reset: + s.reset <- struct{}{} return } drainBuf() @@ -212,25 +204,25 @@ func (s *stream) transport() { } for { - switch s.state { - case stateClose: - drainBuf() - return - case stateReset: + // Reset takes precedent. + select { + case <-s.reset: + s.state = ErrReset s.read.CloseWithError(ErrReset) return default: - panic("invalid state") - case stateOpen: } select { - case s.state = <-s.control: - continue - case o, ok := <-s.toDeliver: - if !ok { - return - } + case <-s.reset: + s.state = ErrReset + s.read.CloseWithError(ErrReset) + return + case <-s.close: + s.state = ErrClosed + drainBuf() + return + case o := <-s.toDeliver: deliverOrWait(o) case <-timer.C: // ok, due to write it out. drainBuf()