package yamux import ( "bytes" "fmt" "io" "io/ioutil" "runtime" "sync" "testing" "time" ) type pipeConn struct { reader *io.PipeReader writer *io.PipeWriter } func (p *pipeConn) Read(b []byte) (int, error) { return p.reader.Read(b) } func (p *pipeConn) Write(b []byte) (int, error) { return p.writer.Write(b) } func (p *pipeConn) Close() error { p.reader.Close() return p.writer.Close() } func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) { read1, write1 := io.Pipe() read2, write2 := io.Pipe() return &pipeConn{read1, write2}, &pipeConn{read2, write1} } func testClientServer() (*Session, *Session) { conf := DefaultConfig() conf.AcceptBacklog = 64 conf.KeepAliveInterval = 100 * time.Millisecond return testClientServerConfig(conf) } func testClientServerConfig(conf *Config) (*Session, *Session) { conn1, conn2 := testConn() client, _ := Client(conn1, conf) server, _ := Server(conn2, conf) return client, server } func TestPing(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() rtt, err := client.Ping() if err != nil { t.Fatalf("err: %v", err) } if rtt == 0 { t.Fatalf("bad: %v", rtt) } rtt, err = server.Ping() if err != nil { t.Fatalf("err: %v", err) } if rtt == 0 { t.Fatalf("bad: %v", rtt) } } func TestAccept(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() if client.NumStreams() != 0 { t.Fatalf("bad") } if server.NumStreams() != 0 { t.Fatalf("bad") } wg := &sync.WaitGroup{} wg.Add(4) go func() { defer wg.Done() stream, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } if id := stream.StreamID(); id != 1 { t.Fatalf("bad: %v", id) } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() go func() { defer wg.Done() stream, err := client.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } if id := stream.StreamID(); id != 2 { t.Fatalf("bad: %v", id) } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() go func() { defer wg.Done() stream, err := server.OpenStream() if err != nil { t.Fatalf("err: %v", err) } if id := stream.StreamID(); id != 2 { t.Fatalf("bad: %v", id) } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() go func() { defer wg.Done() stream, err := client.OpenStream() if err != nil { t.Fatalf("err: %v", err) } if id := stream.StreamID(); id != 1 { t.Fatalf("bad: %v", id) } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() doneCh := make(chan struct{}) go func() { wg.Wait() close(doneCh) }() select { case <-doneCh: case <-time.After(time.Second): panic("timeout") } } func TestSendData_Small(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() wg := &sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() stream, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } if server.NumStreams() != 1 { t.Fatalf("bad") } buf := make([]byte, 4) for i := 0; i < 1000; i++ { n, err := stream.Read(buf) if err != nil { t.Fatalf("err: %v", err) } if n != 4 { t.Fatalf("short read: %d", n) } if string(buf) != "test" { t.Fatalf("bad: %s", buf) } } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() go func() { defer wg.Done() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } if client.NumStreams() != 1 { t.Fatalf("bad") } for i := 0; i < 1000; i++ { n, err := stream.Write([]byte("test")) if err != nil { t.Fatalf("err: %v", err) } if n != 4 { t.Fatalf("short write %d", n) } } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() doneCh := make(chan struct{}) go func() { wg.Wait() close(doneCh) }() select { case <-doneCh: case <-time.After(time.Second): panic("timeout") } if client.NumStreams() != 0 { t.Fatalf("bad") } if server.NumStreams() != 0 { t.Fatalf("bad") } } func TestSendData_Large(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() data := make([]byte, 512*1024) for idx := range data { data[idx] = byte(idx % 256) } wg := &sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() stream, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } buf := make([]byte, 4*1024) for i := 0; i < 128; i++ { n, err := stream.Read(buf) if err != nil { t.Fatalf("err: %v", err) } if n != 4*1024 { t.Fatalf("short read: %d", n) } for idx := range buf { if buf[idx] != byte(idx%256) { t.Fatalf("bad: %v %v %v", i, idx, buf[idx]) } } } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() go func() { defer wg.Done() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } n, err := stream.Write(data) if err != nil { t.Fatalf("err: %v", err) } if n != len(data) { t.Fatalf("short write %d", n) } if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } }() doneCh := make(chan struct{}) go func() { wg.Wait() close(doneCh) }() select { case <-doneCh: case <-time.After(time.Second): panic("timeout") } } func TestGoAway(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() if err := server.GoAway(); err != nil { t.Fatalf("err: %v", err) } _, err := client.Open() if err != ErrRemoteGoAway { t.Fatalf("err: %v", err) } } func TestManyStreams(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() wg := &sync.WaitGroup{} acceptor := func(i int) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() buf := make([]byte, 512) for { n, err := stream.Read(buf) if err == io.EOF { return } if err != nil { t.Fatalf("err: %v", err) } if n == 0 { t.Fatalf("err: %v", err) } } } sender := func(i int) { defer wg.Done() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() msg := fmt.Sprintf("%08d", i) for i := 0; i < 1000; i++ { n, err := stream.Write([]byte(msg)) if err != nil { t.Fatalf("err: %v", err) } if n != len(msg) { t.Fatalf("short write %d", n) } } } for i := 0; i < 50; i++ { wg.Add(2) go acceptor(i) go sender(i) } wg.Wait() } func TestManyStreams_PingPong(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() wg := &sync.WaitGroup{} ping := []byte("ping") pong := []byte("pong") acceptor := func(i int) { defer wg.Done() stream, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() buf := make([]byte, 4) for { n, err := stream.Read(buf) if err == io.EOF { return } if err != nil { t.Fatalf("err: %v", err) } if n != 4 { t.Fatalf("err: %v", err) } if !bytes.Equal(buf, ping) { t.Fatalf("bad: %s", buf) } n, err = stream.Write(pong) if err != nil { t.Fatalf("err: %v", err) } if n != 4 { t.Fatalf("err: %v", err) } } } sender := func(i int) { defer wg.Done() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() buf := make([]byte, 4) for i := 0; i < 1000; i++ { n, err := stream.Write(ping) if err != nil { t.Fatalf("err: %v", err) } if n != 4 { t.Fatalf("short write %d", n) } n, err = stream.Read(buf) if err != nil { t.Fatalf("err: %v", err) } if n != 4 { t.Fatalf("err: %v", err) } if !bytes.Equal(buf, pong) { t.Fatalf("bad: %s", buf) } } } for i := 0; i < 50; i++ { wg.Add(2) go acceptor(i) go sender(i) } wg.Wait() } func TestHalfClose(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } if _, err := stream.Write([]byte("a")); err != nil { t.Fatalf("err: %v", err) } stream2, err := server.Accept() if err != nil { t.Fatalf("err: %v", err) } stream2.Close() // Half close buf := make([]byte, 4) n, err := stream2.Read(buf) if err != nil { t.Fatalf("err: %v", err) } if n != 1 { t.Fatalf("bad: %v", n) } // Send more if _, err := stream.Write([]byte("bcd")); err != nil { t.Fatalf("err: %v", err) } stream.Close() // Read after close n, err = stream2.Read(buf) if err != nil { t.Fatalf("err: %v", err) } if n != 3 { t.Fatalf("bad: %v", n) } // EOF after close n, err = stream2.Read(buf) if err != io.EOF { t.Fatalf("err: %v", err) } if n != 0 { t.Fatalf("bad: %v", n) } } func TestReadDeadline(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() stream2, err := server.Accept() if err != nil { t.Fatalf("err: %v", err) } defer stream2.Close() if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil { t.Fatalf("err: %v", err) } buf := make([]byte, 4) if _, err := stream.Read(buf); err != ErrTimeout { t.Fatalf("err: %v", err) } } func TestWriteDeadline(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() stream2, err := server.Accept() if err != nil { t.Fatalf("err: %v", err) } defer stream2.Close() if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil { t.Fatalf("err: %v", err) } buf := make([]byte, 512) for i := 0; i < int(initialStreamWindow); i++ { _, err := stream.Write(buf) if err != nil && err == ErrTimeout { return } else if err != nil { t.Fatalf("err: %v", err) } } t.Fatalf("Expected timeout") } func TestBacklogExceeded(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() // Fill the backlog max := client.config.AcceptBacklog for i := 0; i < max; i++ { stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() if _, err := stream.Write([]byte("foo")); err != nil { t.Fatalf("err: %v", err) } } // Attempt to open a new stream errCh := make(chan error, 1) go func() { _, err := client.Open() errCh <- err }() // Shutdown the server go func() { time.Sleep(10 * time.Millisecond) server.Close() }() select { case err := <-errCh: if err == nil { t.Fatalf("open should fail") } case <-time.After(time.Second): t.Fatalf("timeout") } } func TestKeepAlive(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() time.Sleep(200 * time.Millisecond) // Ping value should increase client.pingLock.Lock() defer client.pingLock.Unlock() if client.pingID == 0 { t.Fatalf("should ping") } server.pingLock.Lock() defer server.pingLock.Unlock() if server.pingID == 0 { t.Fatalf("should ping") } } func TestLargeWindow(t *testing.T) { conf := DefaultConfig() conf.MaxStreamWindowSize *= 2 client, server := testClientServerConfig(conf) defer client.Close() defer server.Close() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() stream2, err := server.Accept() if err != nil { t.Fatalf("err: %v", err) } defer stream2.Close() stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) buf := make([]byte, conf.MaxStreamWindowSize) n, err := stream.Write(buf) if err != nil { t.Fatalf("err: %v", err) } if n != len(buf) { t.Fatalf("short write: %d", n) } } type UnlimitedReader struct{} func (u *UnlimitedReader) Read(p []byte) (int, error) { runtime.Gosched() return len(p), nil } func TestSendData_VeryLarge(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() var n int64 = 1 * 1024 * 1024 * 1024 var workers int = 16 wg := &sync.WaitGroup{} wg.Add(workers * 2) for i := 0; i < workers; i++ { go func() { defer wg.Done() stream, err := server.AcceptStream() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() buf := make([]byte, 4) _, err = stream.Read(buf) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(buf, []byte{0, 1, 2, 3}) { t.Fatalf("bad header") } recv, err := io.Copy(ioutil.Discard, stream) if err != nil { t.Fatalf("err: %v", err) } if recv != n { t.Fatalf("bad: %v", recv) } }() } for i := 0; i < workers; i++ { go func() { defer wg.Done() stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() _, err = stream.Write([]byte{0, 1, 2, 3}) if err != nil { t.Fatalf("err: %v", err) } unlimited := &UnlimitedReader{} sent, err := io.Copy(stream, io.LimitReader(unlimited, n)) if err != nil { t.Fatalf("err: %v", err) } if sent != n { t.Fatalf("bad: %v", sent) } }() } doneCh := make(chan struct{}) go func() { wg.Wait() close(doneCh) }() select { case <-doneCh: case <-time.After(20 * time.Second): panic("timeout") } } func TestBacklogExceeded_Accept(t *testing.T) { client, server := testClientServer() defer client.Close() defer server.Close() max := 5 * client.config.AcceptBacklog go func() { for i := 0; i < max; i++ { stream, err := server.Accept() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() } }() // Fill the backlog for i := 0; i < max; i++ { stream, err := client.Open() if err != nil { t.Fatalf("err: %v", err) } defer stream.Close() if _, err := stream.Write([]byte("foo")); err != nil { t.Fatalf("err: %v", err) } } }