Commit 8acc21e8 authored by Jeromy's avatar Jeromy
Browse files

Vendor in go-peerstream

parent a9de494f
package frame
import "io"
const (
headerSize = 8
)
type Header []byte
func newHeader() Header {
return Header(make([]byte, headerSize))
}
func (b Header) readFrom(d deserializer) (err error) {
// read the header
if _, err = io.ReadFull(d, []byte(b)); err != nil {
return err
}
return
}
func (b Header) Length() uint16 {
return order.Uint16(b[:2]) & lengthMask
}
func (b Header) SetLength(length int) (err error) {
if length > lengthMask || length < 0 {
return protoError("Frame length %d out of range", length)
}
order.PutUint16(b[:2], uint16(length))
return
}
func (b Header) Type() FrameType {
return FrameType((b[3]) & typeMask)
}
func (b Header) SetType(t FrameType) (err error) {
b[3] = byte(t & typeMask)
return
}
func (b Header) StreamId() StreamId {
return StreamId(order.Uint32(b[4:]) & streamMask)
}
func (b Header) SetStreamId(streamId StreamId) (err error) {
if streamId > streamMask {
return protoError("Stream id %d out of range", streamId)
}
order.PutUint32(b[4:], uint32(streamId))
return
}
func (b Header) Flags() flagsType {
return flagsType(b[2])
}
func (b Header) SetFlags(fl flagsType) (err error) {
b[2] = byte(fl)
return
}
func (b Header) Fin() bool {
return b.Flags().IsSet(flagFin)
}
func (b Header) SetAll(ftype FrameType, length int, streamId StreamId, flags flagsType) (err error) {
if err = b.SetType(ftype); err != nil {
return
}
if err = b.SetLength(length); err != nil {
return
}
if err = b.SetStreamId(streamId); err != nil {
return
}
if err = b.SetFlags(flags); err != nil {
return
}
return
}
package frame
import (
"reflect"
"testing"
)
type HeaderParams struct {
ftype FrameType
length int
streamId StreamId
flags flagsType
}
func (params *HeaderParams) checkDeserialize(t *testing.T, h Header) {
if h.Type() != params.ftype {
t.Errorf("Failed deserialization. Expected type %x, got: %x", params.ftype, h.Type())
}
if h.Length() != uint16(params.length) {
t.Errorf("Failed deserialization. Expected length %x, got: %x", params.length, h.Length())
}
if h.Flags() != params.flags {
t.Errorf("Failed deserialization. Expected flags %x, got: %x", params.flags, h.Flags())
}
if h.StreamId() != params.streamId {
t.Errorf("Failed deserialization. Expected stream id %x, got: %x", params.streamId, h.StreamId())
}
}
func TestHeaderSerialization(t *testing.T) {
t.Parallel()
testCases := []struct {
input HeaderParams
expectedOutput []byte
}{
{
HeaderParams{
ftype: TypeStreamRst,
length: 0x4,
streamId: 0x2843,
flags: 0,
},
[]byte{0, 0x4, 0, 0x2, 0, 0, 0x28, 0x43},
},
{
HeaderParams{
ftype: 0x1F,
length: 0x37BD,
streamId: 0x0,
flags: 0x9,
},
[]byte{0x37, 0xBD, 0x9, 0x1F, 0, 0, 0, 0},
},
{
HeaderParams{
ftype: 0,
length: 0,
streamId: 0,
flags: 0,
},
[]byte{0, 0, 0, 0, 0, 0, 0, 0},
},
{
HeaderParams{
ftype: typeMask,
length: lengthMask,
streamId: streamMask,
flags: flagsMask,
},
[]byte{0x3F, 0xFF, 0xFF, 0x1F, 0x7F, 0xFF, 0xFF, 0xFF},
},
{
HeaderParams{
ftype: 0x1e,
length: 0x1DAA,
streamId: 0x4F224719,
flags: 0x17,
},
[]byte{0x1D, 0xAA, 0x17, 0x1E, 0x4F, 0x22, 0x47, 0x19},
},
}
for _, test := range testCases {
var h Header = Header(make([]byte, headerSize))
h.SetAll(test.input.ftype, test.input.length, test.input.streamId, test.input.flags)
output := []byte(h)
if !reflect.DeepEqual(output, test.expectedOutput) {
t.Errorf("Failed serialization of %v. Expected %x, got: %x", test.input, output, test.expectedOutput)
}
}
}
func TestHeaderDeserialization(t *testing.T) {
t.Parallel()
testCases := []struct {
input []byte
expectedOutput HeaderParams
}{
{
[]byte{0, 0x4, 0, 0x2, 0, 0, 0x28, 0x43},
HeaderParams{
ftype: TypeStreamRst,
length: 0x4,
streamId: 0x2843,
flags: 0,
},
},
{
[]byte{0x37, 0xBD, 0x9, 0x1F, 0, 0, 0, 0},
HeaderParams{
ftype: 0x1F,
length: 0x37BD,
streamId: 0x0,
flags: 0x9,
},
},
{
[]byte{0, 0, 0, 0, 0, 0, 0, 0},
HeaderParams{
ftype: 0,
length: 0,
streamId: 0,
flags: 0,
},
},
{
[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
HeaderParams{
ftype: typeMask,
length: lengthMask,
streamId: streamMask,
flags: flagsMask,
},
},
{
[]byte{0x9D, 0xAA, 0x17, 0xF0, 0xCF, 0x22, 0x47, 0x19},
HeaderParams{
ftype: 0x10,
length: 0x1DAA,
streamId: 0x4F224719,
flags: 0x17,
},
},
}
for _, test := range testCases {
test.expectedOutput.checkDeserialize(t, Header(test.input))
}
}
func TestHeaderRoundTrip(t *testing.T) {
t.Parallel()
headers := []HeaderParams{
HeaderParams{
ftype: TypeStreamRst,
length: 0x4,
streamId: 0x2843,
flags: 0,
},
HeaderParams{
ftype: 0x1F,
length: 0x37BD,
streamId: 0x0,
flags: 0x9,
},
HeaderParams{
ftype: 0,
length: 0,
streamId: 0,
flags: 0,
},
HeaderParams{
ftype: typeMask,
length: lengthMask,
streamId: streamMask,
flags: flagsMask,
},
HeaderParams{
ftype: 0x1e,
length: 0x1DAA,
streamId: 0x4F224719,
flags: 0x17,
},
}
for _, input := range headers {
var h Header = Header(make([]byte, headerSize))
h.SetAll(input.ftype, input.length, input.streamId, input.flags)
input.checkDeserialize(t, h)
}
}
func TestValidStreamIds(t *testing.T) {
t.Parallel()
validStreamIds := []StreamId{
0x0,
0xFF,
0x23C10A8F,
0x7FFFFFFF,
}
for _, validStreamId := range validStreamIds {
var h Header = Header(make([]byte, headerSize))
err := h.SetAll(TypeStreamSyn, 0, validStreamId, 0)
if err != nil {
t.Errorf("Failed to create frame header with valid stream id %d.", validStreamId)
}
}
}
func TestInvalidStreamId(t *testing.T) {
t.Parallel()
invalidStreamIds := []StreamId{
0xF0000000,
0xB012CA8E,
0x80000000,
0xFFFFFFFF,
}
for _, invalidStreamId := range invalidStreamIds {
var h Header = Header(make([]byte, headerSize))
err := h.SetAll(TypeStreamSyn, 0, invalidStreamId, 0)
if err == nil {
t.Errorf("Failed to error on invalid stream id %d.", invalidStreamId)
}
}
}
func TestValidLengths(t *testing.T) {
t.Parallel()
validLengths := []int{
0x0,
0x2FF,
0x301A,
0x3FFF,
}
for _, validLength := range validLengths {
var h Header = Header(make([]byte, headerSize))
err := h.SetAll(TypeStreamSyn, validLength, 0, 0)
if err != nil {
t.Errorf("Failed to create frame header with valid length %d.", validLength)
}
}
}
func TestInvalidLengths(t *testing.T) {
t.Parallel()
invalidLengths := []int{
-1,
0x4000,
0xB012,
0x8000,
0xFFFF,
}
for _, invalidLength := range invalidLengths {
var h Header = Header(make([]byte, headerSize))
err := h.SetAll(TypeStreamSyn, invalidLength, 0, 0)
if err == nil {
t.Errorf("Failed to error on invalid length %d.", invalidLength)
}
}
}
package frame
import (
"io"
)
type Transport interface {
WriteFrame(WFrame) error
ReadFrame() (RFrame, error)
Close() error
}
// A frame can read and write itself to a serializer/deserializer
type RFrame interface {
StreamId() StreamId
Type() FrameType
readFrom(deserializer) error
}
type WFrame interface {
writeTo(serializer) error
}
type deserializer io.Reader
type serializer io.Writer
package frame
import "io"
const (
rstBodySize = 4
rstFrameSize = headerSize + rstBodySize
)
// RsStreamRst is a STREAM_RST frame that is read from a transport
type RStreamRst struct {
Header
body [rstBodySize]byte
}
func (f *RStreamRst) readFrom(d deserializer) (err error) {
if f.Length() != rstBodySize {
return protoError("STREAM_RST length must be %d, got %d", rstBodySize, f.Length())
}
if _, err = io.ReadFull(d, f.body[:]); err != nil {
return
}
return
}
func (f *RStreamRst) ErrorCode() ErrorCode {
return ErrorCode(order.Uint32(f.body[0:]))
}
// WStreamRst is a STREAM_RST frame that can be written, it terminate a stream ungracefully
type WStreamRst struct {
Header
all [rstFrameSize]byte
}
func NewWStreamRst() (f *WStreamRst) {
f = new(WStreamRst)
f.Header = Header(f.all[:headerSize])
return
}
func (f *WStreamRst) writeTo(s serializer) (err error) {
_, err = s.Write(f.all[:])
return
}
func (f *WStreamRst) Set(streamId StreamId, errorCode ErrorCode) (err error) {
if err = f.Header.SetAll(TypeStreamRst, rstBodySize, streamId, 0); err != nil {
return
}
if err = validRstErrorCode(errorCode); err != nil {
return
}
order.PutUint32(f.all[headerSize:], uint32(errorCode))
return
}
func validRstErrorCode(errorCode ErrorCode) error {
if errorCode >= NoSuchError {
return protoError("Invalid error code %d for STREAM_RST", errorCode)
}
return nil
}
package frame
import (
"reflect"
"testing"
)
type RstTestParams struct {
streamId StreamId
errorCode ErrorCode
}
func TestSerializeRst(t *testing.T) {
t.Parallel()
cases := []struct {
params RstTestParams
expected []byte
}{
{
RstTestParams{0x49a1bb00, ProtocolError},
[]byte{0x0, 0x4, 0x0, TypeStreamRst, 0x49, 0xa1, 0xbb, 0x00, 0x0, 0x0, 0x0, ProtocolError},
},
{
RstTestParams{0x0, FlowControlError},
[]byte{0x0, 0x4, 0x0, TypeStreamRst, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, FlowControlError},
},
{
RstTestParams{streamMask, RefusedStream},
[]byte{0x00, 0x4, 0x0, TypeStreamRst, 0x7F, 0xFF, 0xFF, 0xFF, 0x0, 0x0, 0x0, RefusedStream},
},
}
for _, tcase := range cases {
buf, trans := loadedTrans([]byte{})
var f *WStreamRst = NewWStreamRst()
if err := f.Set(tcase.params.streamId, tcase.params.errorCode); err != nil {
t.Fatalf("Error while setting params %v!", tcase.params)
}
if err := f.writeTo(trans); err != nil {
t.Fatalf("Error while writing %v!", tcase.params)
}
if !reflect.DeepEqual(tcase.expected, buf.Bytes()) {
t.Errorf("Failed to serialize STREAM_RST, expected: %v got %v", tcase.expected, buf.Bytes())
}
}
}
func TestDeserializeRst(t *testing.T) {
t.Parallel()
_, trans := loadedTrans([]byte{0x00, rstBodySize, 0x0, TypeStreamRst, 0x7F, 0xFF, 0xFF, 0xFF, 0x0, 0x0, 0x0, RefusedStream})
h := newHeader()
if err := h.readFrom(trans); err != nil {
t.Fatalf("Failed to read header: %v", err)
}
var f RStreamRst
f.Header = h
if err := f.readFrom(trans); err != nil {
t.Fatalf("Error while reading rst frame: %v", err)
}
if f.ErrorCode() != RefusedStream {
t.Errorf("Expected error code %d but got %d", RefusedStream, f.ErrorCode())
}
}
// test a bad frame length of rstBodySize+1
func TestBadLengthRst(t *testing.T) {
t.Parallel()
_, trans := loadedTrans([]byte{0x00, rstBodySize + 1, 0x0, TypeStreamRst, 0x7F, 0xFF, 0xFF, 0xFF, 0x0, 0x0, 0x0, 0x0})
h := newHeader()
if err := h.readFrom(trans); err != nil {
t.Fatalf("Failed to read header: %v", err)
}
var f RStreamRst
f.Header = h
if err := f.readFrom(trans); err == nil {
t.Errorf("Expected error when setting bad rst frame length, got none.")
}
}
// test fewer than rstBodySize bytes available after header
func TestShortReadRst(t *testing.T) {
t.Parallel()
_, trans := loadedTrans([]byte{0x00, rstBodySize, 0x0, TypeStreamRst, 0x7F, 0xFF, 0xFF, 0xFF, 0x1})
h := newHeader()
if err := h.readFrom(trans); err != nil {
t.Fatalf("Failed to read header: %v", err)
}
var f RStreamRst
f.Header = h
if err := f.readFrom(trans); err == nil {
t.Errorf("Expected error when reading incomplete frame, got none.")
}
}
package frame
import (
"fmt"
"io"
)
const (
maxSynBodySize = 8
maxSynFrameSize = headerSize + maxSynBodySize
)
type RStreamSyn struct {
Header
body [maxSynBodySize]byte
streamPriority StreamPriority
streamType StreamType
}
// StreamType returns the stream's defined type as specified by
// the remote endpoint
func (f *RStreamSyn) StreamType() StreamType {
return f.streamType
}
// StreamPriority returns the stream priority set on this frame
func (f *RStreamSyn) StreamPriority() StreamPriority {
return f.streamPriority
}
func (f *RStreamSyn) parseFields() error {
var length uint16 = 0
flags := f.Flags()
if flags.IsSet(flagStreamPriority) {
f.streamPriority = StreamPriority(order.Uint32(f.body[length : length+4]))
length += 4
} else {
f.streamPriority = 0
}
if flags.IsSet(flagStreamType) {
f.streamType = StreamType(order.Uint32(f.body[length : length+4]))
length += 4
} else {
f.streamType = 0
}
if length != f.Length() {
return fmt.Errorf("Expected length %d for flags %v, but got %v", length, flags, f.Length())
}
return nil
}
func (f *RStreamSyn) readFrom(d deserializer) (err error) {
if _, err = io.ReadFull(d, f.body[:f.Length()]); err != nil {
return
}
if err = f.parseFields(); err != nil {
return
}
return
}
type WStreamSyn struct {
Header
data [maxSynFrameSize]byte
length int
}
func (f *WStreamSyn) writeTo(s serializer) (err error) {
_, err = s.Write(f.data[:headerSize+f.Length()])
return
}
func (f *WStreamSyn) Set(streamId StreamId, streamPriority StreamPriority, streamType StreamType, fin bool) (err error) {
var (
flags flagsType
length int = 0
)
// set fin bit
if fin {
flags.Set(flagFin)
}
if streamPriority != 0 {
if streamPriority > priorityMask {
err = protoError("Priority %d is out of range", streamPriority)
return
}
flags.Set(flagStreamPriority)
start := headerSize + length
order.PutUint32(f.data[start:start+4], uint32(streamPriority))
length += 4
}
if streamType != 0 {
flags.Set(flagStreamType)
start := headerSize + length
order.PutUint32(f.data[start:start+4], uint32(streamType))
length += 4
}
// make the frame
if err = f.Header.SetAll(TypeStreamSyn, length, streamId, flags); err != nil {
return
}
return
}
func NewWStreamSyn() (f *WStreamSyn) {
f = new(WStreamSyn)
f.Header = Header(f.data[:headerSize])
return
}
package frame
import (
"encoding/binary"
"io"
)
var (
order = binary.BigEndian
)
// BasicTransport can serialize/deserialize frames on an underlying
// net.Conn to implement the muxado protocol.
type BasicTransport struct {
io.ReadWriteCloser
Header
RStreamSyn
RStreamRst
RStreamData
RStreamWndInc
RGoAway
}
// WriteFrame writes the given frame to the underlying transport
func (t *BasicTransport) WriteFrame(frame WFrame) (err error) {
// each frame knows how to write iteself to the framer
err = frame.writeTo(t)
return
}
// ReadFrame reads the next frame from the underlying transport
func (t *BasicTransport) ReadFrame() (f RFrame, err error) {
// read the header
if _, err = io.ReadFull(t, []byte(t.Header)); err != nil {
return nil, err
}
switch t.Header.Type() {
case TypeStreamSyn:
frame := &t.RStreamSyn
frame.Header = t.Header
err = frame.readFrom(t)
return frame, err
case TypeStreamRst:
frame := &t.RStreamRst
frame.Header = t.Header
err = frame.readFrom(t)
return frame, err
case TypeStreamData:
frame := &t.RStreamData
frame.Header = t.Header
err = frame.readFrom(t)
return frame, err
case TypeStreamWndInc:
frame := &t.RStreamWndInc
frame.Header = t.Header
err = frame.readFrom(t)
return frame, err
case TypeGoAway:
frame := &t.RGoAway
frame.Header = t.Header
err = frame.readFrom(t)
return frame, err
default:
return nil, protoError("Illegal frame type: %d", t.Header.Type())
}
return
}
func NewBasicTransport(rwc io.ReadWriteCloser) *BasicTransport {
trans := &BasicTransport{ReadWriteCloser: rwc, Header: make([]byte, headerSize)}
return trans
}
package frame
const (
// offsets for packing/unpacking frames
lengthOffset = 32 + 16
flagsOffset = 32 + 8
typeOffset = 32 + 3
// masks for packing/unpacking frames
lengthMask = 0x3FFF
streamMask = 0x7FFFFFFF
flagsMask = 0xFF
typeMask = 0x1F
wndIncMask = 0x7FFFFFFF
priorityMask = 0x7FFFFFFF
)
// a frameType is a 5-bit integer in the frame header that identifies the type of frame
type FrameType uint8
const (
TypeStreamSyn = 0x1
TypeStreamRst = 0x2
TypeStreamData = 0x3
TypeStreamWndInc = 0x4
TypeStreamPri = 0x5
TypeGoAway = 0x6
)
// a flagsType is an 8-bit integer containing frame-specific flag bits in the frame header
type flagsType uint8
const (
flagFin = 0x1
flagStreamPriority = 0x2
flagStreamType = 0x4
)
func (ft flagsType) IsSet(f flagsType) bool {
return (ft & f) != 0
}
func (ft *flagsType) Set(f flagsType) {
*ft |= f
}
func (ft *flagsType) Unset(f flagsType) {
*ft = *ft &^ f
}
// StreamId is 31-bit integer uniquely identifying a stream within a session
type StreamId uint32
// StreamPriority is 31-bit integer specifying a stream's priority
type StreamPriority uint32
// StreamType is 32-bit integer specifying a stream's type
type StreamType uint32
// ErrorCode is a 32-bit integer indicating a error condition included in rst/goaway frames
type ErrorCode uint32
package frame
import "io"
const (
wndIncBodySize = 4
wndIncFrameSize = headerSize + wndIncBodySize
)
// Increase a stream's flow control window size
type RStreamWndInc struct {
Header
body [wndIncBodySize]byte
}
func (f *RStreamWndInc) WindowIncrement() (inc uint32) {
return order.Uint32(f.body[:]) & wndIncMask
}
func (f *RStreamWndInc) readFrom(d deserializer) (err error) {
if f.Length() != wndIncBodySize {
return protoError("WND_INC length must be %d, got %d", wndIncBodySize, f.Length())
}
_, err = io.ReadFull(d, f.body[:])
return
}
type WStreamWndInc struct {
Header
data [wndIncFrameSize]byte
}
func (f *WStreamWndInc) writeTo(s serializer) (err error) {
_, err = s.Write(f.data[:])
return
}
func (f *WStreamWndInc) Set(streamId StreamId, inc uint32) (err error) {
if inc > wndIncMask {
return protoError("Window increment %d out of range", inc)
}
order.PutUint32(f.data[headerSize:], inc)
if err = f.Header.SetAll(TypeStreamWndInc, wndIncBodySize, streamId, 0); err != nil {
return
}
return
}
func NewWStreamWndInc() (f *WStreamWndInc) {
f = new(WStreamWndInc)
f.Header = Header(f.data[:headerSize])
return
}
package frame
import (
"reflect"
"testing"
)
type WndIncTestParams struct {
streamId StreamId
inc uint32
}
func TestSerializeWndInc(t *testing.T) {
t.Parallel()
cases := []struct {
params WndIncTestParams
expected []byte
}{
{
WndIncTestParams{0x04b1bd09, 0x0},
[]byte{0x0, 0x4, 0x0, TypeStreamWndInc, 0x04, 0xb1, 0xbd, 0x09, 0x0, 0x0, 0x0, 0x0},
},
{
WndIncTestParams{0x0, 0x12c498},
[]byte{0x0, 0x4, 0x0, TypeStreamWndInc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x12, 0xc4, 0x98},
},
{
WndIncTestParams{streamMask, wndIncMask},
[]byte{0x00, 0x4, 0x0, TypeStreamWndInc, 0x7F, 0xFF, 0xFF, 0xFF, 0x7F, 0xFF, 0xFF, 0xFF},
},
}
for _, tcase := range cases {
buf, trans := loadedTrans([]byte{})
var f *WStreamWndInc = NewWStreamWndInc()
if err := f.Set(tcase.params.streamId, tcase.params.inc); err != nil {
t.Fatalf("Error while setting params %v!", tcase.params)
}
if err := f.writeTo(trans); err != nil {
t.Fatalf("Error while writing %v!", tcase.params)
}
if !reflect.DeepEqual(tcase.expected, buf.Bytes()) {
t.Errorf("Failed to serialize STREAM_WNDINC, expected: %v got %v", tcase.expected, buf.Bytes())
}
}
}
func TestDeserializeWndInc(t *testing.T) {
t.Parallel()
_, trans := loadedTrans([]byte{0x00, wndIncBodySize, 0x0, TypeStreamWndInc, 0x7F, 0xFF, 0xFF, 0xFF, 0x0, 0x0, 0xc9, 0xF1})
h := newHeader()
if err := h.readFrom(trans); err != nil {
t.Fatalf("Failed to read header: %v", err)
}
var f RStreamWndInc
f.Header = h
if err := f.readFrom(trans); err != nil {
t.Fatalf("Error while reading rst frame: %v", err)
}
if f.WindowIncrement() != 0xc9f1 {
t.Errorf("Expected error code %d but got %d", 0xc9f1, f.WindowIncrement())
}
}
// test a bad frame length of wndIncBodySize+1
func TestBadLengthWndInc(t *testing.T) {
t.Parallel()
_, trans := loadedTrans([]byte{0x00, wndIncBodySize + 1, 0x0, TypeStreamWndInc, 0x7F, 0xFF, 0xFF, 0xFF, 0x0, 0x0, 0x0, 0x0})
h := newHeader()
if err := h.readFrom(trans); err != nil {
t.Fatalf("Failed to read header: %v", err)
}
var f RStreamWndInc
f.Header = h
if err := f.readFrom(trans); err == nil {
t.Errorf("Expected error when setting bad wndinc frame length, got none.")
}
}
// test fewer than rstBodySize bytes available after header
func TestShortReadWndInc(t *testing.T) {
t.Parallel()
_, trans := loadedTrans([]byte{0x00, wndIncBodySize, 0x0, TypeStreamWndInc, 0x7F, 0xFF, 0xFF, 0xFF, 0x1})
h := newHeader()
if err := h.readFrom(trans); err != nil {
t.Fatalf("Failed to read header: %v", err)
}
var f RStreamWndInc
f.Header = h
if err := f.readFrom(trans); err == nil {
t.Errorf("Expected error when reading incomplete frame, got none.")
}
}
package proto
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/frame"
"net"
"time"
)
type IStream interface {
Write([]byte) (int, error)
Read([]byte) (int, error)
Close() error
SetDeadline(time.Time) error
SetReadDeadline(time.Time) error
SetWriteDeadline(time.Time) error
HalfClose([]byte) (int, error)
Id() frame.StreamId
StreamType() frame.StreamType
Session() ISession
RemoteAddr() net.Addr
LocalAddr() net.Addr
}
type ISession interface {
Open() (IStream, error)
OpenStream(frame.StreamPriority, frame.StreamType, bool) (IStream, error)
Accept() (IStream, error)
Kill() error
GoAway(frame.ErrorCode, []byte) error
LocalAddr() net.Addr
RemoteAddr() net.Addr
Close() error
Wait() (frame.ErrorCode, error, []byte)
NetListener() net.Listener
NetDial(_, _ string) (net.Conn, error)
}
package proto
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/frame"
"fmt"
"io"
"net"
"reflect"
"sync"
"sync/atomic"
"time"
)
const (
defaultWindowSize = 0x10000 // 64KB
defaultAcceptQueueDepth = 100
MinExtensionType = 0xFFFFFFFF - 0x100 // 512 extensions
)
// private interface for Sessions to call Streams
type stream interface {
IStream
handleStreamData(*frame.RStreamData)
handleStreamWndInc(*frame.RStreamWndInc)
handleStreamRst(*frame.RStreamRst)
closeWith(error)
}
// for extensions
type ExtAccept func() (IStream, error)
type Extension interface {
Start(ISession, ExtAccept) frame.StreamType
}
type deadReason struct {
errorCode frame.ErrorCode
err error
remoteDebug []byte
}
// factory function that creates new streams
type streamFactory func(id frame.StreamId, priority frame.StreamPriority, streamType frame.StreamType, finLocal bool, finRemote bool, windowSize uint32, sess session) stream
// checks the parity of a stream id (local vs remote, client vs server)
type parityFn func(frame.StreamId) bool
// state for each half of the session (remote and local)
type halfState struct {
goneAway int32 // true if that half of the stream has gone away
lastId uint32 // last id used/seen from one half of the session
}
// Session implements a simple streaming session manager. It has the following characteristics:
//
// - When closing the Session, it does not linger, all pending write operations will fail immediately.
// - It completely ignores stream priority when processing and writing frames
// - It offers no customization of settings like window size/ping time
type Session struct {
conn net.Conn // connection the transport is running over
transport frame.Transport // transport
streams StreamMap // all active streams
local halfState // client state
remote halfState // server state
syn *frame.WStreamSyn // STREAM_SYN frame for opens
wr sync.Mutex // synchronization when writing frames
accept chan stream // new streams opened by the remote
diebit int32 // true if we're dying
remoteDebug []byte // debugging data sent in the remote's GoAway frame
defaultWindowSize uint32 // window size when creating new streams
newStream streamFactory // factory function to make new streams
dead chan deadReason // dead
isLocal parityFn // determines if a stream id is local or remote
exts map[frame.StreamType]chan stream // map of extension stream type -> accept channel for the extension
}
func NewSession(conn net.Conn, newStream streamFactory, isClient bool, exts []Extension) ISession {
sess := &Session{
conn: conn,
transport: frame.NewBasicTransport(conn),
streams: NewConcurrentStreamMap(),
local: halfState{lastId: 0},
remote: halfState{lastId: 0},
syn: frame.NewWStreamSyn(),
diebit: 0,
defaultWindowSize: defaultWindowSize,
accept: make(chan stream, defaultAcceptQueueDepth),
newStream: newStream,
dead: make(chan deadReason, 1), // don't block die() if there is no Wait call
exts: make(map[frame.StreamType]chan stream),
}
if isClient {
sess.isLocal = sess.isClient
sess.local.lastId += 1
} else {
sess.isLocal = sess.isServer
sess.remote.lastId += 1
}
for _, ext := range exts {
sess.startExtension(ext)
}
go sess.reader()
return sess
}
////////////////////////////////
// public interface
////////////////////////////////
func (s *Session) Open() (IStream, error) {
return s.OpenStream(0, 0, false)
}
func (s *Session) OpenStream(priority frame.StreamPriority, streamType frame.StreamType, fin bool) (ret IStream, err error) {
// check if the remote has gone away
if atomic.LoadInt32(&s.remote.goneAway) == 1 {
return nil, fmt.Errorf("Failed to create stream, remote has gone away.")
}
// this lock prevents the following race:
// goroutine1 goroutine2
// - inc stream id
// - inc stream id
// - send streamsyn
// - send streamsyn
s.wr.Lock()
// get the next id we can use
nextId := frame.StreamId(atomic.AddUint32(&s.local.lastId, 2))
// make the stream
str := s.newStream(nextId, priority, streamType, fin, false, s.defaultWindowSize, s)
// add to to the stream map
s.streams.Set(nextId, str)
// write the frame
if err = s.syn.Set(nextId, priority, streamType, fin); err != nil {
s.wr.Unlock()
s.die(frame.InternalError, err)
return
}
if err = s.transport.WriteFrame(s.syn); err != nil {
s.wr.Unlock()
s.die(frame.InternalError, err)
return
}
s.wr.Unlock()
return str, nil
}
func (s *Session) Accept() (str IStream, err error) {
var ok bool
if str, ok = <-s.accept; !ok {
return nil, fmt.Errorf("Session closed")
}
return
}
func (s *Session) Kill() error {
return s.transport.Close()
}
func (s *Session) Close() error {
return s.die(frame.NoError, fmt.Errorf("Session Close()"))
}
func (s *Session) GoAway(errorCode frame.ErrorCode, debug []byte) (err error) {
if !atomic.CompareAndSwapInt32(&s.local.goneAway, 0, 1) {
return fmt.Errorf("Already sent GoAway!")
}
s.wr.Lock()
f := frame.NewWGoAway()
remoteId := frame.StreamId(atomic.LoadUint32(&s.remote.lastId))
if err = f.Set(remoteId, errorCode, debug); err != nil {
s.wr.Unlock()
s.die(frame.InternalError, err)
return
}
if err = s.transport.WriteFrame(f); err != nil {
s.wr.Unlock()
s.die(frame.InternalError, err)
return
}
s.wr.Unlock()
return
}
func (s *Session) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}
func (s *Session) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *Session) Wait() (frame.ErrorCode, error, []byte) {
reason := <-s.dead
return reason.errorCode, reason.err, reason.remoteDebug
}
////////////////////////////////
// private interface for streams
////////////////////////////////
// removeStream removes a stream from this session's stream registry
//
// It does not error if the stream is not present
func (s *Session) removeStream(id frame.StreamId) {
s.streams.Delete(id)
return
}
// writeFrame writes the given frame to the transport and returns the error from the write operation
func (s *Session) writeFrame(f frame.WFrame, dl time.Time) (err error) {
s.wr.Lock()
s.conn.SetWriteDeadline(dl)
err = s.transport.WriteFrame(f)
s.wr.Unlock()
return
}
// die closes the session cleanly with the given error and protocol error code
func (s *Session) die(errorCode frame.ErrorCode, err error) error {
// only one shutdown ever happens
if !atomic.CompareAndSwapInt32(&s.diebit, 0, 1) {
return fmt.Errorf("Shutdown already in progress")
}
// send a go away frame
s.GoAway(errorCode, []byte(err.Error()))
// now we're safe to stop accepting incoming connections
close(s.accept)
// we cleaned up as best as possible, close the transport
s.transport.Close()
// notify all of the streams that we're closing
s.streams.Each(func(id frame.StreamId, str stream) {
str.closeWith(fmt.Errorf("Session closed"))
})
s.dead <- deadReason{errorCode, err, s.remoteDebug}
return nil
}
////////////////////////////////
// internal methods
////////////////////////////////
// reader() reads frames from the underlying transport and handles passes them to handleFrame
func (s *Session) reader() {
defer s.recoverPanic("reader()")
// close all of the extension accept channels when we're done
// we do this here instead of in die() since otherwise it wouldn't
// be safe to access s.exts
defer func() {
for _, extAccept := range s.exts {
close(extAccept)
}
}()
for {
f, err := s.transport.ReadFrame()
if err != nil {
// if we fail to read a frame, terminate the session
_, ok := err.(*frame.FramingError)
if ok {
s.die(frame.ProtocolError, err)
} else {
s.die(frame.InternalError, err)
}
return
}
s.handleFrame(f)
}
}
func (s *Session) handleFrame(rf frame.RFrame) {
switch f := rf.(type) {
case *frame.RStreamSyn:
// if we're going away, refuse new streams
if atomic.LoadInt32(&s.local.goneAway) == 1 {
rstF := frame.NewWStreamRst()
rstF.Set(f.StreamId(), frame.RefusedStream)
go s.writeFrame(rstF, time.Time{})
return
}
if f.StreamId() <= frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) {
s.die(frame.ProtocolError, fmt.Errorf("Stream id %d is less than last remote id.", f.StreamId()))
return
}
if s.isLocal(f.StreamId()) {
s.die(frame.ProtocolError, fmt.Errorf("Stream id has wrong parity for remote endpoint: %d", f.StreamId()))
return
}
// update last remote id
atomic.StoreUint32(&s.remote.lastId, uint32(f.StreamId()))
// make the new stream
str := s.newStream(f.StreamId(), f.StreamPriority(), f.StreamType(), false, f.Fin(), s.defaultWindowSize, s)
// add it to the stream map
s.streams.Set(f.StreamId(), str)
// check if this is an extension stream
if f.StreamType() >= MinExtensionType {
extAccept, ok := s.exts[f.StreamType()]
if !ok {
// Extension type of stream not registered
fRst := frame.NewWStreamRst()
if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil {
s.die(frame.InternalError, err)
}
s.wr.Lock()
defer s.wr.Unlock()
s.transport.WriteFrame(fRst)
} else {
extAccept <- str
}
return
}
// put the new stream on the accept channel
s.accept <- str
case *frame.RStreamData:
if str := s.getStream(f.StreamId()); str != nil {
str.handleStreamData(f)
} else {
// if we get a data frame on a non-existent connection, we still
// need to read out the frame body so that the stream stays in a
// good state. read the payload into a throwaway buffer
discard := make([]byte, f.Length())
io.ReadFull(f.Reader(), discard)
// DATA frames on closed connections are just stream-level errors
fRst := frame.NewWStreamRst()
if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil {
s.die(frame.InternalError, err)
}
s.wr.Lock()
defer s.wr.Unlock()
s.transport.WriteFrame(fRst)
return
}
case *frame.RStreamRst:
// delegate to the stream to handle these frames
if str := s.getStream(f.StreamId()); str != nil {
str.handleStreamRst(f)
}
case *frame.RStreamWndInc:
// delegate to the stream to handle these frames
if str := s.getStream(f.StreamId()); str != nil {
str.handleStreamWndInc(f)
}
case *frame.RGoAway:
atomic.StoreInt32(&s.remote.goneAway, 1)
s.remoteDebug = f.Debug()
lastId := f.LastStreamId()
s.streams.Each(func(id frame.StreamId, str stream) {
// close all streams that we opened above the last handled id
if s.isLocal(str.Id()) && str.Id() > lastId {
str.closeWith(fmt.Errorf("Remote is going away"))
}
})
default:
s.die(frame.ProtocolError, fmt.Errorf("Unrecognized frame type: %v", reflect.TypeOf(f)))
return
}
}
func (s *Session) recoverPanic(prefix string) {
if r := recover(); r != nil {
s.die(frame.InternalError, fmt.Errorf("%s panic: %v", prefix, r))
}
}
func (s *Session) getStream(id frame.StreamId) (str stream) {
// decide if this id is in the "idle" state (i.e. greater than any we've seen for that parity)
var lastId *uint32
if s.isLocal(id) {
lastId = &s.local.lastId
} else {
lastId = &s.remote.lastId
}
if uint32(id) > atomic.LoadUint32(lastId) {
s.die(frame.ProtocolError, fmt.Errorf("%d is an invalid, unassigned stream id", id))
}
// find the stream in the stream map
var ok bool
if str, ok = s.streams.Get(id); !ok {
return nil
}
return
}
// check if a stream id is for a client stream. client streams are odd
func (s *Session) isClient(id frame.StreamId) bool {
return uint32(id)&1 == 1
}
func (s *Session) isServer(id frame.StreamId) bool {
return !s.isClient(id)
}
//////////////////////////////////////////////
// session extensions
//////////////////////////////////////////////
func (s *Session) startExtension(ext Extension) {
accept := make(chan stream)
extAccept := func() (IStream, error) {
s, ok := <-accept
if !ok {
return nil, fmt.Errorf("Failed to accept connection, shutting down")
}
return s, nil
}
extType := ext.Start(s, extAccept)
s.exts[extType] = accept
}
//////////////////////////////////////////////
// net adaptors
//////////////////////////////////////////////
func (s *Session) NetDial(_, _ string) (net.Conn, error) {
str, err := s.Open()
return net.Conn(str), err
}
func (s *Session) NetListener() net.Listener {
return &netListenerAdaptor{s}
}
type netListenerAdaptor struct {
*Session
}
func (a *netListenerAdaptor) Addr() net.Addr {
return a.LocalAddr()
}
func (a *netListenerAdaptor) Accept() (net.Conn, error) {
str, err := a.Session.Accept()
return net.Conn(str), err
}
package proto
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/frame"
"io"
"io/ioutil"
"net"
"net/http"
"testing"
"time"
)
func fakeStreamFactory(id frame.StreamId, priority frame.StreamPriority, streamType frame.StreamType, finLocal bool, finRemote bool, windowSize uint32, sess session) stream {
return new(fakeStream)
}
type fakeStream struct {
}
func (s *fakeStream) Write([]byte) (int, error) { return 0, nil }
func (s *fakeStream) Read([]byte) (int, error) { return 0, nil }
func (s *fakeStream) Close() error { return nil }
func (s *fakeStream) SetDeadline(time.Time) error { return nil }
func (s *fakeStream) SetReadDeadline(time.Time) error { return nil }
func (s *fakeStream) SetWriteDeadline(time.Time) error { return nil }
func (s *fakeStream) HalfClose([]byte) (int, error) { return 0, nil }
func (s *fakeStream) Id() frame.StreamId { return 0 }
func (s *fakeStream) StreamType() frame.StreamType { return 0 }
func (s *fakeStream) Session() ISession { return nil }
func (s *fakeStream) RemoteAddr() net.Addr { return nil }
func (s *fakeStream) LocalAddr() net.Addr { return nil }
func (s *fakeStream) handleStreamData(*frame.RStreamData) {}
func (s *fakeStream) handleStreamWndInc(*frame.RStreamWndInc) {}
func (s *fakeStream) handleStreamRst(*frame.RStreamRst) {}
func (s *fakeStream) closeWith(error) {}
type fakeConn struct {
in *io.PipeReader
out *io.PipeWriter
closed bool
}
func (c *fakeConn) SetDeadline(time.Time) error { return nil }
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
func (c *fakeConn) Close() error { c.closed = true; c.in.Close(); return c.out.Close() }
func (c *fakeConn) Read(p []byte) (int, error) { return c.in.Read(p) }
func (c *fakeConn) Write(p []byte) (int, error) { return c.out.Write(p) }
func (c *fakeConn) Discard() { go io.Copy(ioutil.Discard, c.in) }
func newFakeConnPair() (local *fakeConn, remote *fakeConn) {
local, remote = new(fakeConn), new(fakeConn)
local.in, remote.out = io.Pipe()
remote.in, local.out = io.Pipe()
return
}
func TestFailWrongClientParity(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
// don't need the remote output
remote.Discard()
// false for a server session
s := NewSession(local, fakeStreamFactory, false, []Extension{})
// 300 is even, and only servers send even stream ids
f := frame.NewWStreamSyn()
f.Set(300, 0, 0, false)
// send the frame into the session
trans := frame.NewBasicTransport(remote)
trans.WriteFrame(f)
// wait for failure
code, err, _ := s.Wait()
if code != frame.ProtocolError {
t.Errorf("Session not terminated with protocol error. Got %d, expected %d. Session error: %v", code, frame.ProtocolError, err)
}
if !local.closed {
t.Errorf("Session transport not closed after protocol failure.")
}
}
func TestWrongServerParity(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
// true for a client session
s := NewSession(local, fakeStreamFactory, true, []Extension{})
// don't need the remote output
remote.Discard()
// 300 is even, and only servers send even stream ids
f := frame.NewWStreamSyn()
f.Set(301, 0, 0, false)
// send the frame into the session
trans := frame.NewBasicTransport(remote)
trans.WriteFrame(f)
// wait for failure
code, err, _ := s.Wait()
if code != frame.ProtocolError {
t.Errorf("Session not terminated with protocol error. Got %d, expected %d. Session error: %v", code, frame.ProtocolError, err)
}
if !local.closed {
t.Errorf("Session transport not closed after protocol failure.")
}
}
func TestAcceptStream(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
// don't need the remote output
remote.Discard()
// true for a client session
s := NewSession(local, NewStream, true, []Extension{})
defer s.Close()
f := frame.NewWStreamSyn()
f.Set(300, 0, 0, false)
// send the frame into the session
trans := frame.NewBasicTransport(remote)
trans.WriteFrame(f)
done := make(chan int)
go func() {
defer func() { done <- 1 }()
// wait for accept
str, err := s.Accept()
if err != nil {
t.Errorf("Error accepting stream: %v", err)
return
}
if str.Id() != frame.StreamId(300) {
t.Errorf("Stream has wrong id. Expected %d, got %d", str.Id(), 300)
}
}()
select {
case <-time.After(time.Second):
t.Fatalf("Timed out!")
case <-done:
}
}
func TestSynLowId(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
// don't need the remote output
remote.Discard()
// true for a client session
s := NewSession(local, fakeStreamFactory, true, []Extension{})
// Start a stream
f := frame.NewWStreamSyn()
f.Set(302, 0, 0, false)
// send the frame into the session
trans := frame.NewBasicTransport(remote)
trans.WriteFrame(f)
// accept it
s.Accept()
// Start a closed stream at a lower id number
f.Set(300, 0, 0, false)
// send the frame into the session
trans.WriteFrame(f)
code, err, _ := s.Wait()
if code != frame.ProtocolError {
t.Errorf("Session not terminated with protocol error, got %d expected %d. Error: %v", code, frame.ProtocolError, err)
}
}
// Check that sending a frame of the wrong size responds with FRAME_SIZE_ERROR
func TestFrameSizeError(t *testing.T) {
}
// Check that we get a protocol error for sending STREAM_DATA on a stream id that was never opened
func TestDataOnClosed(t *testing.T) {
}
// Check that we get nothing for sending STREAM_WND_INC on a stream id that was never opened
func TestWndIncOnClosed(t *testing.T) {
}
// Check that we get nothing for sending STREAM_RST on a stream id that was never opened
func TestRstOnClosed(t *testing.T) {
}
func TestGoAway(t *testing.T) {
}
func TestCloseGoAway(t *testing.T) {
}
func TestKill(t *testing.T) {
}
// make sure we get a valid syn frame from opening a new stream
func TestOpen(t *testing.T) {
}
// test opening a new stream that is immediately half-closed
func TestOpenWithFin(t *testing.T) {
}
// validate that a session fulfills the net.Listener interface
// compile-only check
func TestNetListener(t *testing.T) {
t.Parallel()
_ = func() {
s := NewSession(new(fakeConn), NewStream, false, []Extension{})
http.Serve(s.NetListener(), nil)
}
}
func TestNetListenerAccept(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
sLocal := NewSession(local, NewStream, false, []Extension{})
sRemote := NewSession(remote, NewStream, true, []Extension{})
go func() {
_, err := sRemote.Open()
if err != nil {
t.Errorf("Failed to open stream: %v", err)
return
}
}()
l := sLocal.NetListener()
_, err := l.Accept()
if err != nil {
t.Fatalf("Failed to accept stream: %v", err)
}
}
// set up a fake extension which tries to accept a stream.
// we're testing to make sure that when the remote side closes the connection
// that the extension actually gets an error back from its accept() method
type fakeExt struct {
closeOk chan int
}
func (e *fakeExt) Start(sess ISession, accept ExtAccept) frame.StreamType {
go func() {
_, err := accept()
if err != nil {
// we should get an error when the session close
e.closeOk <- 1
}
}()
return MinExtensionType
}
func TestExtensionCleanupAccept(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
closeOk := make(chan int)
_ = NewSession(local, NewStream, false, []Extension{&fakeExt{closeOk}})
sRemote := NewSession(remote, NewStream, true, []Extension{})
sRemote.Close()
select {
case <-time.After(time.Second):
t.Fatalf("Timed out!")
case <-closeOk:
}
}
func TestWriteAfterClose(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
sLocal := NewSession(local, NewStream, false, []Extension{})
sRemote := NewSession(remote, NewStream, true, []Extension{})
closed := make(chan int)
go func() {
stream, err := sRemote.Open()
if err != nil {
t.Errorf("Failed to open stream: %v", err)
return
}
<-closed
if _, err = stream.Write([]byte("test!")); err != nil {
t.Errorf("Failed to write test data: %v", err)
return
}
if _, err := sRemote.Open(); err != nil {
t.Errorf("Failed to open second stream: %v", err)
return
}
}()
stream, err := sLocal.Accept()
if err != nil {
t.Fatalf("Failed to accept stream!")
}
// tell the other side that we closed so they can write late
stream.Close()
closed <- 1
if _, err = sLocal.Accept(); err != nil {
t.Fatalf("Failed to accept second connection: %v", err)
}
}
package proto
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/buffer"
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/frame"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
var (
zeroTime time.Time
resetRemoveDelay = 10 * time.Second
closeError = fmt.Errorf("Stream closed")
)
type Stream struct {
id frame.StreamId // stream id (const)
streamType frame.StreamType // related stream id (const)
session session // the parent session (const)
inBuffer *buffer.Inbound // buffer for data coming in from the remote side
outBuffer *buffer.Outbound // manages size of the outbound window
sentRst uint32 // == 1 only if we sent a reset to close this connection
writer sync.Mutex // only one writer at a time
wdata *frame.WStreamData // the frame this stream is currently writing
winc *frame.WStreamWndInc // window increment currently being written
readDeadline time.Time // deadline for reads (protected by buffer mutex)
writeDeadline time.Time // deadline for writes (protected by writer mutex)
}
// private interface for Streams to call Sessions
type session interface {
ISession
writeFrame(frame.WFrame, time.Time) error
die(frame.ErrorCode, error) error
removeStream(frame.StreamId)
}
////////////////////////////////
// public interface
////////////////////////////////
func NewStream(id frame.StreamId, priority frame.StreamPriority, streamType frame.StreamType, finLocal bool, finRemote bool, windowSize uint32, sess session) stream {
str := &Stream{
id: id,
inBuffer: buffer.NewInbound(int(windowSize)),
outBuffer: buffer.NewOutbound(int(windowSize)),
streamType: streamType,
session: sess,
wdata: frame.NewWStreamData(),
winc: frame.NewWStreamWndInc(),
}
if finLocal {
str.inBuffer.SetError(io.EOF)
}
if finRemote {
str.outBuffer.SetError(fmt.Errorf("Stream closed"))
}
return str
}
func (s *Stream) Write(buf []byte) (n int, err error) {
return s.write(buf, false)
}
func (s *Stream) Read(buf []byte) (n int, err error) {
// read from the buffer
n, err = s.inBuffer.Read(buf)
// if we read more than zero, we send a window update
if n > 0 {
errWnd := s.sendWindowUpdate(uint32(n))
if errWnd != nil {
err = errWnd
s.die(frame.InternalError, err)
}
}
return
}
// Close closes the stream in a manner that attempts to emulate a net.Conn's Close():
// - It calls HalfClose() with an empty buffer to half-close the stream on the remote side
// - It calls closeWith() so that all future Read/Write operations will fail
// - If the stream receives another STREAM_DATA frame from the remote side, it will send a STREAM_RST with a CANCELED error code
func (s *Stream) Close() error {
s.HalfClose([]byte{})
s.closeWith(closeError)
return nil
}
func (s *Stream) SetDeadline(deadline time.Time) (err error) {
if err = s.SetReadDeadline(deadline); err != nil {
return
}
if err = s.SetWriteDeadline(deadline); err != nil {
return
}
return
}
func (s *Stream) SetReadDeadline(dl time.Time) error {
s.inBuffer.SetDeadline(dl)
return nil
}
func (s *Stream) SetWriteDeadline(dl time.Time) error {
s.writer.Lock()
s.writeDeadline = dl
s.writer.Unlock()
return nil
}
func (s *Stream) HalfClose(buf []byte) (n int, err error) {
return s.write(buf, true)
}
func (s *Stream) Id() frame.StreamId {
return s.id
}
func (s *Stream) StreamType() frame.StreamType {
return s.streamType
}
func (s *Stream) Session() ISession {
return s.session
}
func (s *Stream) LocalAddr() net.Addr {
return s.session.LocalAddr()
}
func (s *Stream) RemoteAddr() net.Addr {
return s.session.RemoteAddr()
}
/////////////////////////////////////
// session's stream interface
/////////////////////////////////////
func (s *Stream) handleStreamData(f *frame.RStreamData) {
// skip writing for zero-length frames (typically for sending FIN)
if f.Length() > 0 {
// write the data into the buffer
if _, err := s.inBuffer.ReadFrom(f.Reader()); err != nil {
if err == buffer.FullError {
s.resetWith(frame.FlowControlError, fmt.Errorf("Flow control buffer overflowed"))
} else if err == closeError {
// We're trying to emulate net.Conn's Close() behavior where we close our side of the connection,
// and if we get any more frames from the other side, we RST it.
s.resetWith(frame.Cancel, fmt.Errorf("Stream closed"))
} else if err == buffer.AlreadyClosed {
// there was already an error set
s.resetWith(frame.StreamClosed, err)
} else {
// the transport returned some sort of IO error
s.die(frame.ProtocolError, err)
}
return
}
}
if f.Fin() {
s.inBuffer.SetError(io.EOF)
s.maybeRemove()
}
}
func (s *Stream) handleStreamRst(f *frame.RStreamRst) {
s.closeWith(fmt.Errorf("Stream reset by peer with error %d", f.ErrorCode()))
}
func (s *Stream) handleStreamWndInc(f *frame.RStreamWndInc) {
s.outBuffer.Increment(int(f.WindowIncrement()))
}
func (s *Stream) closeWith(err error) {
s.outBuffer.SetError(err)
s.inBuffer.SetError(err)
s.session.removeStream(s.id)
}
////////////////////////////////
// internal methods
////////////////////////////////
func (s *Stream) closeWithAndRemoveLater(err error) {
s.outBuffer.SetError(err)
s.inBuffer.SetError(err)
time.AfterFunc(resetRemoveDelay, func() {
s.session.removeStream(s.id)
})
}
func (s *Stream) maybeRemove() {
if buffer.BothClosed(s.inBuffer, s.outBuffer) {
s.session.removeStream(s.id)
}
}
func (s *Stream) resetWith(errorCode frame.ErrorCode, resetErr error) {
// only ever send one reset
if !atomic.CompareAndSwapUint32(&s.sentRst, 0, 1) {
return
}
// close the stream
s.closeWithAndRemoveLater(resetErr)
// make the reset frame
rst := frame.NewWStreamRst()
if err := rst.Set(s.id, errorCode); err != nil {
s.die(frame.InternalError, err)
}
// need write lock to make sure no data frames get sent after we send the reset
s.writer.Lock()
// send it
if err := s.session.writeFrame(rst, zeroTime); err != nil {
s.writer.Unlock()
s.die(frame.InternalError, err)
}
s.writer.Unlock()
}
func (s *Stream) write(buf []byte, fin bool) (n int, err error) {
// a write call can pass a buffer larger that we can send in a single frame
// only allow one writer at a time to prevent interleaving frames from concurrent writes
s.writer.Lock()
bufSize := len(buf)
bytesRemaining := bufSize
for bytesRemaining > 0 || fin {
// figure out the most we can write in a single frame
writeReqSize := min(0x3FFF, bytesRemaining)
// and then reduce that to however much is available in the window
// this blocks until window is available and may not return all that we asked for
var writeSize int
if writeSize, err = s.outBuffer.Decrement(writeReqSize); err != nil {
s.writer.Unlock()
return
}
// calculate the slice of the buffer we'll write
start, end := n, n+writeSize
// only send fin for the last frame
finBit := fin && end == bufSize
// make the frame
if err = s.wdata.Set(s.id, buf[start:end], finBit); err != nil {
s.writer.Unlock()
s.die(frame.InternalError, err)
return
}
// write the frame
if err = s.session.writeFrame(s.wdata, s.writeDeadline); err != nil {
s.writer.Unlock()
return
}
// update our counts
n += writeSize
bytesRemaining -= writeSize
if finBit {
s.outBuffer.SetError(fmt.Errorf("Stream closed"))
s.maybeRemove()
// handles the empty buffer case with fin case
fin = false
}
}
s.writer.Unlock()
return
}
// sendWindowUpdate sends a window increment frame
// with the given increment
func (s *Stream) sendWindowUpdate(inc uint32) (err error) {
// send a window update
if err = s.winc.Set(s.id, inc); err != nil {
return
}
// XXX: write this async? We can only write one at
// a time if we're not allocating new ones from the heap
if err = s.session.writeFrame(s.winc, zeroTime); err != nil {
return
}
return
}
// die is called when a protocol error occurs and the entire
// session must be destroyed.
func (s *Stream) die(errorCode frame.ErrorCode, err error) {
s.closeWith(fmt.Errorf("Stream closed on error: %v", err))
s.session.die(errorCode, err)
}
func min(n1, n2 int) int {
if n1 > n2 {
return n2
} else {
return n1
}
}
package proto
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/frame"
"sync"
)
const (
initMapCapacity = 128 // not too much extra memory wasted to avoid allocations
)
type StreamMap interface {
Get(frame.StreamId) (stream, bool)
Set(frame.StreamId, stream)
Delete(frame.StreamId)
Each(func(frame.StreamId, stream))
}
// ConcurrentStreamMap is a map of stream ids -> streams guarded by a read/write lock
type ConcurrentStreamMap struct {
sync.RWMutex
table map[frame.StreamId]stream
}
func (m *ConcurrentStreamMap) Get(id frame.StreamId) (s stream, ok bool) {
m.RLock()
s, ok = m.table[id]
m.RUnlock()
return
}
func (m *ConcurrentStreamMap) Set(id frame.StreamId, str stream) {
m.Lock()
m.table[id] = str
m.Unlock()
}
func (m *ConcurrentStreamMap) Delete(id frame.StreamId) {
m.Lock()
delete(m.table, id)
m.Unlock()
}
func (m *ConcurrentStreamMap) Each(fn func(frame.StreamId, stream)) {
m.Lock()
streams := make(map[frame.StreamId]stream, len(m.table))
for k, v := range m.table {
streams[k] = v
}
m.Unlock()
for id, str := range streams {
fn(id, str)
}
}
func NewConcurrentStreamMap() *ConcurrentStreamMap {
return &ConcurrentStreamMap{table: make(map[frame.StreamId]stream, initMapCapacity)}
}
package proto
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/frame"
"fmt"
"io"
"io/ioutil"
"testing"
"time"
)
func TestSendHalfCloseWithZeroWindow(t *testing.T) {
t.Parallel()
local, remote := newFakeConnPair()
s := NewSession(local, NewStream, false, []Extension{})
s.(*Session).defaultWindowSize = 10
go func() {
trans := frame.NewBasicTransport(remote)
f, err := trans.ReadFrame()
if err != nil {
t.Errorf("Failed to read next frame: %v", err)
return
}
_, ok := f.(*frame.RStreamSyn)
if !ok {
t.Errorf("Wrong frame type. Got %v, expected %v", f.Type(), frame.TypeStreamSyn)
return
}
f, err = trans.ReadFrame()
if err != nil {
t.Errorf("Failed to read next frame: %v", err)
return
}
fr, ok := f.(*frame.RStreamData)
if !ok {
t.Errorf("Wrong frame type. Got %v, expected %v", f.Type(), frame.TypeStreamData)
return
}
if fr.Length() != 10 {
t.Errorf("Wrong data length. Got %v, expected %d", fr.Length(), 10)
return
}
n, err := io.CopyN(ioutil.Discard, fr.Reader(), 10)
if n != 10 {
t.Errorf("Wrong read size. Got %d, expected %d", n, 10)
return
}
f, err = trans.ReadFrame()
if err != nil {
t.Errorf("Failed to read next frame: %v", err)
return
}
fr, ok = f.(*frame.RStreamData)
if !ok {
t.Errorf("Wrong frame type. Got %v, expected %v", f.Type(), frame.TypeStreamData)
return
}
if !fr.Fin() {
t.Errorf("Wrong frame flags. Expected fin flag to be set.")
return
}
trans.ReadFrame()
}()
str, err := s.Open()
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
_, err = str.Write(make([]byte, 10))
if err != nil {
t.Fatalf("Failed to write data: %v", err)
}
_, err = str.HalfClose([]byte{})
if err != nil {
t.Fatalf("Failed to half-close with an empty buffer")
}
}
func TestDataAfterRst(t *testing.T) {
local, remote := newFakeConnPair()
_ = NewSession(local, NewStream, false, []Extension{})
trans := frame.NewBasicTransport(remote)
// make sure that we get an RST STREAM_CLOSED
done := make(chan int)
go func() {
defer func() { done <- 1 }()
f, err := trans.ReadFrame()
if err != nil {
t.Errorf("Failed to read frame sent from session: %v", err)
return
}
fr, ok := f.(*frame.RStreamRst)
if !ok {
t.Errorf("Frame is not STREAM_RST: %v", f)
return
}
if fr.ErrorCode() != frame.StreamClosed {
t.Errorf("Error code on STREAM_RST is not STREAM_CLOSED. Got %d, expected %d", fr.ErrorCode(), frame.StreamClosed)
return
}
}()
fSyn := frame.NewWStreamSyn()
if err := fSyn.Set(301, 0, 0, false); err != nil {
t.Fatalf("Failed to make syn frame: %v", err)
}
if err := trans.WriteFrame(fSyn); err != nil {
t.Fatalf("Failed to send syn: %v", err)
}
fRst := frame.NewWStreamRst()
if err := fRst.Set(301, frame.Cancel); err != nil {
t.Fatal("Failed to make rst frame: %v", err)
}
if err := trans.WriteFrame(fRst); err != nil {
t.Fatalf("Failed to write rst frame: %v", err)
}
fData := frame.NewWStreamData()
if err := fData.Set(301, []byte{0xa, 0xFF}, false); err != nil {
t.Fatalf("Failed to set data frame")
}
trans.WriteFrame(fData)
<-done
}
func TestFlowControlError(t *testing.T) {
local, remote := newFakeConnPair()
s := NewSession(local, NewStream, false, []Extension{})
s.(*Session).defaultWindowSize = 10
trans := frame.NewBasicTransport(remote)
// make sure that we get an RST FLOW_CONTROL_ERROR
done := make(chan int)
go func() {
defer func() { done <- 1 }()
f, err := trans.ReadFrame()
if err != nil {
t.Errorf("Failed to read frame sent from session: %v", err)
return
}
fr, ok := f.(*frame.RStreamRst)
if !ok {
t.Errorf("Frame is not STREAM_RST: %v", f)
return
}
if fr.ErrorCode() != frame.FlowControlError {
t.Errorf("Error code on STREAM_RST is not FLOW_CONTROL_ERROR. Got %d, expected %d", fr.ErrorCode(), frame.FlowControlError)
return
}
}()
fSyn := frame.NewWStreamSyn()
if err := fSyn.Set(301, 0, 0, false); err != nil {
t.Fatalf("Failed to make syn frame: %v", err)
}
if err := trans.WriteFrame(fSyn); err != nil {
t.Fatalf("Failed to send syn: %v", err)
}
fData := frame.NewWStreamData()
if err := fData.Set(301, make([]byte, 11), false); err != nil {
t.Fatalf("Failed to set data frame")
}
trans.WriteFrame(fData)
<-done
}
func TestTolerateLateFrameAfterRst(t *testing.T) {
local, remote := newFakeConnPair()
s := NewSession(local, NewStream, false, []Extension{})
trans := frame.NewBasicTransport(remote)
// make sure that we don't get any error on a late frame
done := make(chan int)
go func() {
defer func() { done <- 1 }()
// read syn
trans.ReadFrame()
// read rst
trans.ReadFrame()
// should block
if f, err := trans.ReadFrame(); err != nil {
t.Errorf("Error reading frame: %v", err)
} else {
t.Errorf("Got frame that we shouldn't have read: %v. Type: %v", f, f.Type())
}
}()
str, err := s.Open()
if err != nil {
t.Fatalf("failed to open stream")
}
str.(*Stream).resetWith(frame.Cancel, fmt.Errorf("cancel"))
fData := frame.NewWStreamData()
if err := fData.Set(str.Id(), []byte{0x1, 0x2, 0x3}, false); err != nil {
t.Fatalf("Failed to set data frame")
}
trans.WriteFrame(fData)
select {
case <-done:
t.Fatalf("Stream sent response to late DATA frame")
case <-time.After(1 * time.Second):
// ok
}
}
// Test that we remove a stream from the session if both sides half-close
func TestRemoveAfterHalfClose(t *testing.T) {
local, remote := newFakeConnPair()
remote.Discard()
s := NewSession(local, NewStream, false, []Extension{})
trans := frame.NewBasicTransport(remote)
// open stream
str, err := s.Open()
if err != nil {
t.Fatalf("failed to open stream")
}
// half close remote side (true means half-close)
fData := frame.NewWStreamData()
if err := fData.Set(str.Id(), []byte{0x1, 0x2, 0x3}, true); err != nil {
t.Fatalf("Failed to set data frame")
}
if err := trans.WriteFrame(fData); err != nil {
t.Fatalf("Failed to write data frame")
}
// half-close local side
str.HalfClose([]byte{0xFF, 0xFE, 0xFD, 0xFC})
// yield so the stream can process
time.Sleep(0)
// verify stream is removed
if stream, ok := s.(*Session).streams.Get(str.Id()); ok {
t.Fatalf("Expected stream %d to be removed after both sides half-closed, but found: %v!", str.Id(), stream)
}
}
// Test that we get a RST if we send a DATA frame after we send a DATA frame with a FIN
func TestDataAfterFin(t *testing.T) {
}
package muxado
import (
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto"
"QmfEm573cZeq3LpgccZMpngV6dXbm5gfU23F5nNUuhSxxJ/muxado/proto/ext"
"crypto/tls"
"net"
)
// A Listener accepts new connections from its net.Listener
// and begins muxado server connections on them.
//
// It's API is very similar to a net.Listener, but it returns
// muxado.Sessions instead of net.Conn's.
type Listener struct {
wrapped net.Listener
}
// Accept the next connection from the listener and begin
// a muxado session on it.
func (l *Listener) Accept() (Session, error) {
conn, err := l.wrapped.Accept()
if err != nil {
return nil, err
}
return Server(conn), nil
}
// Addr returns the bound address of the wrapped net.Listener
func (l *Listener) Addr() net.Addr {
return l.wrapped.Addr()
}
// Close closes the wrapped net.Listener
func (l *Listener) Close() error {
return l.wrapped.Close()
}
// Server returns a muxado server session using conn as the transport.
func Server(conn net.Conn) Session {
return &sessionAdaptor{proto.NewSession(conn, proto.NewStream, false, []proto.Extension{ext.NewDefaultHeartbeat()})}
}
// Listen binds to a network address and returns a Listener which accepts
// new connections and starts muxado server sessions on them.
func Listen(network, addr string) (*Listener, error) {
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &Listener{l}, nil
}
// ListenTLS binds to a network address and accepts new TLS-encrypted connections.
// It returns a Listener which starts new muxado server sessions on the connections.
func ListenTLS(network, addr string, tlsConfig *tls.Config) (*Listener, error) {
l, err := tls.Listen(network, addr, tlsConfig)
if err != nil {
return nil, err
}
return &Listener{l}, nil
}
// NewListener creates a new muxado listener which creates new muxado server sessions
// by accepting connections from the given net.Listener
func NewListener(l net.Listener) *Listener {
return &Listener{l}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment