Commit 8acc21e8 authored by Jeromy's avatar Jeromy
Browse files

Vendor in go-peerstream

parent a9de494f
package spdystream
import (
"fmt"
"io"
"net"
"net/http"
"sync"
"testing"
)
func configureServer() (io.Closer, string, *sync.WaitGroup) {
authenticated = true
wg := &sync.WaitGroup{}
server, listen, serverErr := runServer(wg)
if serverErr != nil {
panic(serverErr)
}
return server, listen, wg
}
func BenchmarkDial10000(b *testing.B) {
server, addr, wg := configureServer()
defer func() {
server.Close()
wg.Wait()
}()
for i := 0; i < b.N; i++ {
conn, dialErr := net.Dial("tcp", addr)
if dialErr != nil {
panic(fmt.Sprintf("Error dialing server: %s", dialErr))
}
conn.Close()
}
}
func BenchmarkDialWithSPDYStream10000(b *testing.B) {
server, addr, wg := configureServer()
defer func() {
server.Close()
wg.Wait()
}()
for i := 0; i < b.N; i++ {
conn, dialErr := net.Dial("tcp", addr)
if dialErr != nil {
b.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
b.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
closeErr := spdyConn.Close()
if closeErr != nil {
b.Fatalf("Error closing connection: %s, closeErr")
}
}
}
func benchmarkStreamWithDataAndSize(size uint64, b *testing.B) {
server, addr, wg := configureServer()
defer func() {
server.Close()
wg.Wait()
}()
for i := 0; i < b.N; i++ {
conn, dialErr := net.Dial("tcp", addr)
if dialErr != nil {
b.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
b.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(MirrorStreamHandler)
stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
writer := make([]byte, size)
stream.Write(writer)
if err != nil {
panic(err)
}
reader := make([]byte, size)
stream.Read(reader)
stream.Close()
closeErr := spdyConn.Close()
if closeErr != nil {
b.Fatalf("Error closing connection: %s, closeErr")
}
}
}
func BenchmarkStreamWith1Byte10000(b *testing.B) { benchmarkStreamWithDataAndSize(1, b) }
func BenchmarkStreamWith1KiloByte10000(b *testing.B) { benchmarkStreamWithDataAndSize(1024, b) }
func BenchmarkStreamWith1Megabyte10000(b *testing.B) { benchmarkStreamWithDataAndSize(1024*1024, b) }
package spdystream
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
func TestSpdyStreams(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}
waitErr := stream.Wait()
if waitErr != nil {
t.Fatalf("Error waiting for stream: %s", waitErr)
}
message := []byte("hello")
writeErr := stream.WriteData(message, false)
if writeErr != nil {
t.Fatalf("Error writing data")
}
buf := make([]byte, 10)
n, readErr := stream.Read(buf)
if readErr != nil {
t.Fatalf("Error reading data from stream: %s", readErr)
}
if n != 5 {
t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 5", n)
}
if bytes.Compare(buf[:n], message) != 0 {
t.Fatalf("Did not receive expected message:\nActual: %s\nExpectd: %s", buf, message)
}
headers := http.Header{
"TestKey": []string{"TestVal"},
}
sendErr := stream.SendHeader(headers, false)
if sendErr != nil {
t.Fatalf("Error sending headers: %s", sendErr)
}
receiveHeaders, receiveErr := stream.ReceiveHeader()
if receiveErr != nil {
t.Fatalf("Error receiving headers: %s", receiveErr)
}
if len(receiveHeaders) != 1 {
t.Fatalf("Unexpected number of headers:\nActual: %d\nExpecting:%d", len(receiveHeaders), 1)
}
testVal := receiveHeaders.Get("TestKey")
if testVal != "TestVal" {
t.Fatalf("Wrong test value:\nActual: %q\nExpecting: %q", testVal, "TestVal")
}
writeErr = stream.WriteData(message, true)
if writeErr != nil {
t.Fatalf("Error writing data")
}
smallBuf := make([]byte, 3)
n, readErr = stream.Read(smallBuf)
if readErr != nil {
t.Fatalf("Error reading data from stream: %s", readErr)
}
if n != 3 {
t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 3", n)
}
if bytes.Compare(smallBuf[:n], []byte("hel")) != 0 {
t.Fatalf("Did not receive expected message:\nActual: %s\nExpectd: %s", smallBuf[:n], message)
}
n, readErr = stream.Read(smallBuf)
if readErr != nil {
t.Fatalf("Error reading data from stream: %s", readErr)
}
if n != 2 {
t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 2", n)
}
if bytes.Compare(smallBuf[:n], []byte("lo")) != 0 {
t.Fatalf("Did not receive expected message:\nActual: %s\nExpected: lo", smallBuf[:n])
}
n, readErr = stream.Read(buf)
if readErr != io.EOF {
t.Fatalf("Expected EOF reading from finished stream, read %d bytes", n)
}
// Closing again should return error since stream is already closed
streamCloseErr := stream.Close()
if streamCloseErr == nil {
t.Fatalf("No error closing finished stream")
}
if streamCloseErr != ErrWriteClosedStream {
t.Fatalf("Unexpected error closing stream: %s", streamCloseErr)
}
streamResetErr := stream.Reset()
if streamResetErr != nil {
t.Fatalf("Error reseting stream: %s", streamResetErr)
}
authenticated = false
badStream, badStreamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if badStreamErr != nil {
t.Fatalf("Error creating stream: %s", badStreamErr)
}
waitErr = badStream.Wait()
if waitErr == nil {
t.Fatalf("Did not receive error creating stream")
}
if waitErr != ErrReset {
t.Fatalf("Unexpected error creating stream: %s", waitErr)
}
streamCloseErr = badStream.Close()
if streamCloseErr == nil {
t.Fatalf("No error closing bad stream")
}
spdyCloseErr := spdyConn.Close()
if spdyCloseErr != nil {
t.Fatalf("Error closing spdy connection: %s", spdyCloseErr)
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestPing(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
pingTime, pingErr := spdyConn.Ping()
if pingErr != nil {
t.Fatalf("Error pinging server: %s", pingErr)
}
if pingTime == time.Duration(0) {
t.Fatalf("Expecting non-zero ping time")
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestHalfClose(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}
waitErr := stream.Wait()
if waitErr != nil {
t.Fatalf("Error waiting for stream: %s", waitErr)
}
message := []byte("hello and will read after close")
writeErr := stream.WriteData(message, false)
if writeErr != nil {
t.Fatalf("Error writing data")
}
streamCloseErr := stream.Close()
if streamCloseErr != nil {
t.Fatalf("Error closing stream: %s", streamCloseErr)
}
buf := make([]byte, 40)
n, readErr := stream.Read(buf)
if readErr != nil {
t.Fatalf("Error reading data from stream: %s", readErr)
}
if n != 31 {
t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 5", n)
}
if bytes.Compare(buf[:n], message) != 0 {
t.Fatalf("Did not receive expected message:\nActual: %s\nExpectd: %s", buf, message)
}
spdyCloseErr := spdyConn.Close()
if spdyCloseErr != nil {
t.Fatalf("Error closing spdy connection: %s", spdyCloseErr)
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestUnexpectedRemoteConnectionClosed(t *testing.T) {
tt := []struct {
closeReceiver bool
closeSender bool
}{
{closeReceiver: true, closeSender: false},
{closeReceiver: false, closeSender: true},
{closeReceiver: false, closeSender: false},
}
for tix, tc := range tt {
listener, listenErr := net.Listen("tcp", "localhost:0")
if listenErr != nil {
t.Fatalf("Error listening: %v", listenErr)
}
var serverConn net.Conn
var connErr error
go func() {
serverConn, connErr = listener.Accept()
if connErr != nil {
t.Fatalf("Error accepting: %v", connErr)
}
serverSpdyConn, _ := NewConnection(serverConn, true)
go serverSpdyConn.Serve(func(stream *Stream) {
stream.SendReply(http.Header{}, tc.closeSender)
})
}()
conn, dialErr := net.Dial("tcp", listener.Addr().String())
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}
waitErr := stream.Wait()
if waitErr != nil {
t.Fatalf("Error waiting for stream: %s", waitErr)
}
if tc.closeReceiver {
// make stream half closed, receive only
stream.Close()
}
streamch := make(chan error, 1)
go func() {
b := make([]byte, 1)
_, err := stream.Read(b)
streamch <- err
}()
closeErr := serverConn.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
select {
case e := <-streamch:
if e == nil || e != io.EOF {
t.Fatalf("(%d) Expected to get an EOF stream error", tix)
}
}
closeErr = conn.Close()
if closeErr != nil {
t.Fatalf("Error closing client connection: %s", closeErr)
}
listenErr = listener.Close()
if listenErr != nil {
t.Fatalf("Error closing listener: %s", listenErr)
}
}
}
func TestCloseNotification(t *testing.T) {
listener, listenErr := net.Listen("tcp", "localhost:0")
if listenErr != nil {
t.Fatalf("Error listening: %v", listenErr)
}
listen := listener.Addr().String()
serverConnChan := make(chan net.Conn)
go func() {
serverConn, err := listener.Accept()
if err != nil {
t.Fatalf("Error accepting: %v", err)
}
serverSpdyConn, err := NewConnection(serverConn, true)
if err != nil {
t.Fatalf("Error creating server connection: %v", err)
}
go serverSpdyConn.Serve(NoOpStreamHandler)
<-serverSpdyConn.CloseChan()
serverConnChan <- serverConn
}()
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
// close client conn
err := conn.Close()
if err != nil {
t.Fatalf("Error closing client connection: %v", err)
}
var serverConn net.Conn
select {
case serverConn = <-serverConnChan:
}
err = serverConn.Close()
if err != nil {
t.Fatalf("Error closing serverConn: %v", err)
}
listenErr = listener.Close()
if listenErr != nil {
t.Fatalf("Error closing listener: %s", listenErr)
}
}
func TestIdleShutdownRace(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
authenticated = true
stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}
spdyConn.SetIdleTimeout(5 * time.Millisecond)
go func() {
time.Sleep(5 * time.Millisecond)
stream.Reset()
}()
select {
case <-spdyConn.CloseChan():
case <-time.After(20 * time.Millisecond):
t.Fatal("Timed out waiting for idle connection closure")
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestIdleNoTimeoutSet(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
select {
case <-spdyConn.CloseChan():
t.Fatal("Unexpected connection closure")
case <-time.After(10 * time.Millisecond):
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestIdleClearTimeout(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
spdyConn.SetIdleTimeout(10 * time.Millisecond)
spdyConn.SetIdleTimeout(0)
select {
case <-spdyConn.CloseChan():
t.Fatal("Unexpected connection closure")
case <-time.After(20 * time.Millisecond):
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestIdleNoData(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
spdyConn.SetIdleTimeout(10 * time.Millisecond)
<-spdyConn.CloseChan()
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestIdleWithData(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
spdyConn.SetIdleTimeout(25 * time.Millisecond)
authenticated = true
stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}
writeCh := make(chan struct{})
go func() {
b := []byte{1, 2, 3, 4, 5}
for i := 0; i < 10; i++ {
_, err = stream.Write(b)
if err != nil {
t.Fatalf("Error writing to stream: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
close(writeCh)
}()
writesFinished := false
Loop:
for {
select {
case <-writeCh:
writesFinished = true
case <-spdyConn.CloseChan():
if !writesFinished {
t.Fatal("Connection closed before all writes finished")
}
break Loop
}
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestIdleRace(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
spdyConn.SetIdleTimeout(10 * time.Millisecond)
authenticated = true
for i := 0; i < 10; i++ {
_, err := spdyConn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}
}
<-spdyConn.CloseChan()
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestHalfClosedIdleTimeout(t *testing.T) {
listener, listenErr := net.Listen("tcp", "localhost:0")
if listenErr != nil {
t.Fatalf("Error listening: %v", listenErr)
}
listen := listener.Addr().String()
go func() {
serverConn, err := listener.Accept()
if err != nil {
t.Fatalf("Error accepting: %v", err)
}
serverSpdyConn, err := NewConnection(serverConn, true)
if err != nil {
t.Fatalf("Error creating server connection: %v", err)
}
go serverSpdyConn.Serve(func(s *Stream) {
s.SendReply(http.Header{}, true)
})
serverSpdyConn.SetIdleTimeout(10 * time.Millisecond)
}()
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}
time.Sleep(20 * time.Millisecond)
stream.Reset()
err = spdyConn.Close()
if err != nil {
t.Fatalf("Error closing client spdy conn: %v", err)
}
}
func TestStreamReset(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}
buf := []byte("dskjahfkdusahfkdsahfkdsafdkas")
for i := 0; i < 10; i++ {
if _, err := stream.Write(buf); err != nil {
t.Fatalf("Error writing to stream: %s", err)
}
}
for i := 0; i < 10; i++ {
if _, err := stream.Read(buf); err != nil {
t.Fatalf("Error reading from stream: %s", err)
}
}
// fmt.Printf("Resetting...\n")
if err := stream.Reset(); err != nil {
t.Fatalf("Error reseting stream: %s", err)
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
func TestStreamResetWithDataRemaining(t *testing.T) {
var wg sync.WaitGroup
server, listen, serverErr := runServer(&wg)
if serverErr != nil {
t.Fatalf("Error initializing server: %s", serverErr)
}
conn, dialErr := net.Dial("tcp", listen)
if dialErr != nil {
t.Fatalf("Error dialing server: %s", dialErr)
}
spdyConn, spdyErr := NewConnection(conn, false)
if spdyErr != nil {
t.Fatalf("Error creating spdy connection: %s", spdyErr)
}
go spdyConn.Serve(NoOpStreamHandler)
authenticated = true
stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
if streamErr != nil {
t.Fatalf("Error creating stream: %s", streamErr)
}
buf := []byte("dskjahfkdusahfkdsahfkdsafdkas")
for i := 0; i < 10; i++ {
if _, err := stream.Write(buf); err != nil {
t.Fatalf("Error writing to stream: %s", err)
}
}
// read a bit to make sure a goroutine gets to <-dataChan
if _, err := stream.Read(buf); err != nil {
t.Fatalf("Error reading from stream: %s", err)
}
// fmt.Printf("Resetting...\n")
if err := stream.Reset(); err != nil {
t.Fatalf("Error reseting stream: %s", err)
}
closeErr := server.Close()
if closeErr != nil {
t.Fatalf("Error shutting down server: %s", closeErr)
}
wg.Wait()
}
type roundTripper struct {
conn net.Conn
}
func (s *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
r := *req
req = &r
conn, err := net.Dial("tcp", req.URL.Host)
if err != nil {
return nil, err
}
err = req.Write(conn)
if err != nil {
return nil, err
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return nil, err
}
s.conn = conn
return resp, nil
}
// see https://github.com/GoogleCloudPlatform/kubernetes/issues/4882
func TestFramingAfterRemoteConnectionClosed(t *testing.T) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
streamCh := make(chan *Stream)
w.WriteHeader(http.StatusSwitchingProtocols)
netconn, _, _ := w.(http.Hijacker).Hijack()
conn, _ := NewConnection(netconn, true)
go conn.Serve(func(s *Stream) {
s.SendReply(http.Header{}, false)
streamCh <- s
})
stream := <-streamCh
io.Copy(stream, stream)
closeChan := make(chan struct{})
go func() {
stream.Reset()
conn.Close()
close(closeChan)
}()
<-closeChan
}))
server.Start()
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("Error creating request: %s", err)
}
rt := &roundTripper{}
client := &http.Client{Transport: rt}
_, err = client.Do(req)
if err != nil {
t.Fatalf("unexpected error from client.Do: %s", err)
}
conn, err := NewConnection(rt.conn, false)
go conn.Serve(NoOpStreamHandler)
stream, err := conn.CreateStream(http.Header{}, nil, false)
if err != nil {
t.Fatalf("error creating client stream: %s", err)
}
n, err := stream.Write([]byte("hello"))
if err != nil {
t.Fatalf("error writing to stream: %s", err)
}
if n != 5 {
t.Fatalf("Expected to write 5 bytes, but actually wrote %d", n)
}
b := make([]byte, 5)
n, err = stream.Read(b)
if err != nil {
t.Fatalf("error reading from stream: %s", err)
}
if n != 5 {
t.Fatalf("Expected to read 5 bytes, but actually read %d", n)
}
if e, a := "hello", string(b[0:n]); e != a {
t.Fatalf("expected '%s', got '%s'", e, a)
}
stream.Reset()
conn.Close()
}
var authenticated bool
func authStreamHandler(stream *Stream) {
if !authenticated {
stream.Refuse()
}
MirrorStreamHandler(stream)
}
func runServer(wg *sync.WaitGroup) (io.Closer, string, error) {
listener, listenErr := net.Listen("tcp", "localhost:0")
if listenErr != nil {
return nil, "", listenErr
}
wg.Add(1)
go func() {
for {
conn, connErr := listener.Accept()
if connErr != nil {
break
}
spdyConn, _ := NewConnection(conn, true)
go spdyConn.Serve(authStreamHandler)
}
wg.Done()
}()
return listener, listener.Addr().String(), nil
}
package spdystream
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"QmYewWU9ZnQR7Gct9tNZd97i9tGnyCZfNVLM2GGfNEj5jP/spdystream/spdy"
)
var (
ErrUnreadPartialData = errors.New("unread partial data")
)
type Stream struct {
streamId spdy.StreamId
parent *Stream
conn *Connection
startChan chan error
dataLock sync.RWMutex
dataChan chan []byte
unread []byte
priority uint8
headers http.Header
headerChan chan http.Header
finishLock sync.Mutex
finished bool
replyCond *sync.Cond
replied bool
closeLock sync.Mutex
closeChan chan bool
}
// WriteData writes data to stream, sending a dataframe per call
func (s *Stream) WriteData(data []byte, fin bool) error {
s.waitWriteReply()
var flags spdy.DataFlags
if fin {
flags = spdy.DataFlagFin
s.finishLock.Lock()
if s.finished {
s.finishLock.Unlock()
return ErrWriteClosedStream
}
s.finished = true
s.finishLock.Unlock()
}
dataFrame := &spdy.DataFrame{
StreamId: s.streamId,
Flags: flags,
Data: data,
}
debugMessage("(%p) (%d) Writing data frame", s, s.streamId)
return s.conn.framer.WriteFrame(dataFrame)
}
// Write writes bytes to a stream, calling write data for each call.
func (s *Stream) Write(data []byte) (n int, err error) {
err = s.WriteData(data, false)
if err == nil {
n = len(data)
}
return
}
// Read reads bytes from a stream, a single read will never get more
// than what is sent on a single data frame, but a multiple calls to
// read may get data from the same data frame.
func (s *Stream) Read(p []byte) (n int, err error) {
if s.unread == nil {
select {
case <-s.closeChan:
return 0, io.EOF
case read, ok := <-s.dataChan:
if !ok {
return 0, io.EOF
}
s.unread = read
}
}
n = copy(p, s.unread)
if n < len(s.unread) {
s.unread = s.unread[n:]
} else {
s.unread = nil
}
return
}
// ReadData reads an entire data frame and returns the byte array
// from the data frame. If there is unread data from the result
// of a Read call, this function will return an ErrUnreadPartialData.
func (s *Stream) ReadData() ([]byte, error) {
debugMessage("(%p) Reading data from %d", s, s.streamId)
if s.unread != nil {
return nil, ErrUnreadPartialData
}
select {
case <-s.closeChan:
return nil, io.EOF
case read, ok := <-s.dataChan:
if !ok {
return nil, io.EOF
}
return read, nil
}
}
func (s *Stream) waitWriteReply() {
if s.replyCond != nil {
s.replyCond.L.Lock()
for !s.replied {
s.replyCond.Wait()
}
s.replyCond.L.Unlock()
}
}
// Wait waits for the stream to receive a reply.
func (s *Stream) Wait() error {
return s.WaitTimeout(time.Duration(0))
}
// WaitTimeout waits for the stream to receive a reply or for timeout.
// When the timeout is reached, ErrTimeout will be returned.
func (s *Stream) WaitTimeout(timeout time.Duration) error {
var timeoutChan <-chan time.Time
if timeout > time.Duration(0) {
timeoutChan = time.After(timeout)
}
select {
case err := <-s.startChan:
if err != nil {
return err
}
break
case <-timeoutChan:
return ErrTimeout
}
return nil
}
// Close closes the stream by sending an empty data frame with the
// finish flag set, indicating this side is finished with the stream.
func (s *Stream) Close() error {
select {
case <-s.closeChan:
// Stream is now fully closed
s.conn.removeStream(s)
default:
break
}
return s.WriteData([]byte{}, true)
}
// Reset sends a reset frame, putting the stream into the fully closed state.
func (s *Stream) Reset() error {
s.conn.removeStream(s)
return s.resetStream()
}
func (s *Stream) resetStream() error {
s.finishLock.Lock()
if s.finished {
s.finishLock.Unlock()
return nil
}
s.finished = true
s.finishLock.Unlock()
s.closeRemoteChannels()
resetFrame := &spdy.RstStreamFrame{
StreamId: s.streamId,
Status: spdy.Cancel,
}
return s.conn.framer.WriteFrame(resetFrame)
}
// CreateSubStream creates a stream using the current as the parent
func (s *Stream) CreateSubStream(headers http.Header, fin bool) (*Stream, error) {
return s.conn.CreateStream(headers, s, fin)
}
// SetPriority sets the stream priority, does not affect the
// remote priority of this stream after Open has been called.
// Valid values are 0 through 7, 0 being the highest priority
// and 7 the lowest.
func (s *Stream) SetPriority(priority uint8) {
s.priority = priority
}
// SendHeader sends a header frame across the stream
func (s *Stream) SendHeader(headers http.Header, fin bool) error {
return s.conn.sendHeaders(headers, s, fin)
}
// SendReply sends a reply on a stream, only valid to be called once
// when handling a new stream
func (s *Stream) SendReply(headers http.Header, fin bool) error {
if s.replyCond == nil {
return errors.New("cannot reply on initiated stream")
}
s.replyCond.L.Lock()
defer s.replyCond.L.Unlock()
if s.replied {
return nil
}
err := s.conn.sendReply(headers, s, fin)
if err != nil {
return err
}
s.replied = true
s.replyCond.Broadcast()
return nil
}
// Refuse sends a reset frame with the status refuse, only
// valid to be called once when handling a new stream. This
// may be used to indicate that a stream is not allowed
// when http status codes are not being used.
func (s *Stream) Refuse() error {
if s.replied {
return nil
}
s.replied = true
return s.conn.sendReset(spdy.RefusedStream, s)
}
// Cancel sends a reset frame with the status canceled. This
// can be used at any time by the creator of the Stream to
// indicate the stream is no longer needed.
func (s *Stream) Cancel() error {
return s.conn.sendReset(spdy.Cancel, s)
}
// ReceiveHeader receives a header sent on the other side
// of the stream. This function will block until a header
// is received or stream is closed.
func (s *Stream) ReceiveHeader() (http.Header, error) {
select {
case <-s.closeChan:
break
case header, ok := <-s.headerChan:
if !ok {
return nil, fmt.Errorf("header chan closed")
}
return header, nil
}
return nil, fmt.Errorf("stream closed")
}
// Parent returns the parent stream
func (s *Stream) Parent() *Stream {
return s.parent
}
// Headers returns the headers used to create the stream
func (s *Stream) Headers() http.Header {
return s.headers
}
// String returns the string version of stream using the
// streamId to uniquely identify the stream
func (s *Stream) String() string {
return fmt.Sprintf("stream:%d", s.streamId)
}
// Identifier returns a 32 bit identifier for the stream
func (s *Stream) Identifier() uint32 {
return uint32(s.streamId)
}
// IsFinished returns whether the stream has finished
// sending data
func (s *Stream) IsFinished() bool {
return s.finished
}
// Implement net.Conn interface
func (s *Stream) LocalAddr() net.Addr {
return s.conn.conn.LocalAddr()
}
func (s *Stream) RemoteAddr() net.Addr {
return s.conn.conn.RemoteAddr()
}
// TODO set per stream values instead of connection-wide
func (s *Stream) SetDeadline(t time.Time) error {
return s.conn.conn.SetDeadline(t)
}
func (s *Stream) SetReadDeadline(t time.Time) error {
return s.conn.conn.SetReadDeadline(t)
}
func (s *Stream) SetWriteDeadline(t time.Time) error {
return s.conn.conn.SetWriteDeadline(t)
}
func (s *Stream) closeRemoteChannels() {
s.closeLock.Lock()
defer s.closeLock.Unlock()
select {
case <-s.closeChan:
default:
close(s.closeChan)
s.dataLock.Lock()
defer s.dataLock.Unlock()
close(s.dataChan)
}
}
package spdystream
import (
"log"
"os"
)
var (
DEBUG = os.Getenv("DEBUG")
)
func debugMessage(fmt string, args ...interface{}) {
if DEBUG != "" {
log.Printf(fmt, args...)
}
}
package ws
import (
"QmNvACkuNdmJwK4SBHLrxDjEerWqSFnd2qy7CKcn4ouZ3p/websocket"
"io"
"log"
"time"
)
// Wrap an HTTP2 connection over WebSockets and
// use the underlying WebSocket framing for proxy
// compatibility.
type Conn struct {
*websocket.Conn
reader io.Reader
}
func NewConnection(w *websocket.Conn) *Conn {
return &Conn{Conn: w}
}
func (c Conn) Write(b []byte) (int, error) {
err := c.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
return 0, err
}
return len(b), nil
}
func (c Conn) Read(b []byte) (int, error) {
if c.reader == nil {
t, r, err := c.NextReader()
if err != nil {
return 0, err
}
if t != websocket.BinaryMessage {
log.Printf("ws: ignored non-binary message in stream")
return 0, nil
}
c.reader = r
}
n, err := c.reader.Read(b)
if err != nil {
if err == io.EOF {
c.reader = nil
}
return n, err
}
return n, nil
}
func (c Conn) SetDeadline(t time.Time) error {
if err := c.Conn.SetReadDeadline(t); err != nil {
return err
}
if err := c.Conn.SetWriteDeadline(t); err != nil {
return err
}
return nil
}
func (c Conn) Close() error {
err := c.Conn.Close()
return err
}
# go-multiplex
A super simple stream muxing library compatible with [multiplex](http://github.com/maxogden/multiplex)
## Usage
```go
mplex := multiplex.NewMultiplex(mysocket)
s := mplex.NewStream()
s.Write([]byte("Hello World!")
s.Close()
mplex.Serve(func(s *multiplex.Stream) {
// echo back everything received
io.Copy(s, s)
})
```
package multiplex
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
)
const (
NewStream = iota
Receiver
Initiator
Unknown
Close
)
var _ = ioutil.ReadAll
var _ = bufio.NewReadWriter
var _ = binary.MaxVarintLen16
type msg struct {
header uint64
data []byte
err chan<- error
}
type Stream struct {
id uint64
name string
header uint64
closed chan struct{}
data_in chan []byte
data_out chan<- msg
extra []byte
}
func newStream(id uint64, name string, initiator bool, send chan<- msg) *Stream {
var hfn uint64
if initiator {
hfn = 2
} else {
hfn = 1
}
return &Stream{
id: id,
name: name,
header: (id << 3) | hfn,
data_in: make(chan []byte, 8),
data_out: send,
closed: make(chan struct{}),
}
}
func (s *Stream) Name() string {
return s.name
}
func (s *Stream) receive(b []byte) {
select {
case s.data_in <- b:
case <-s.closed:
}
}
func (m *Multiplex) Accept() (*Stream, error) {
select {
case s, ok := <-m.nstreams:
if !ok {
return nil, errors.New("multiplex closed")
}
return s, nil
case err := <-m.errs:
return nil, err
case <-m.closed:
return nil, errors.New("multiplex closed")
}
}
func (s *Stream) Read(b []byte) (int, error) {
if s.extra == nil {
select {
case <-s.closed:
return 0, io.EOF
case read, ok := <-s.data_in:
if !ok {
return 0, io.EOF
}
s.extra = read
}
}
n := copy(b, s.extra)
if n < len(s.extra) {
s.extra = s.extra[n:]
} else {
s.extra = nil
}
return n, nil
}
func (s *Stream) Write(b []byte) (int, error) {
errs := make(chan error, 1)
select {
case s.data_out <- msg{header: s.header, data: b, err: errs}:
select {
case err := <-errs:
return len(b), err
case <-s.closed:
return 0, errors.New("stream closed")
}
case <-s.closed:
return 0, errors.New("stream closed")
}
}
func (s *Stream) Close() error {
select {
case <-s.closed:
return nil
default:
close(s.closed)
select {
case s.data_out <- msg{
header: (s.id << 3) | Close,
err: make(chan error, 1), //throw away error, whatever
}:
default:
}
close(s.data_in)
return nil
}
}
type Multiplex struct {
con io.ReadWriteCloser
buf *bufio.Reader
nextID uint64
outchan chan msg
closed chan struct{}
initiator bool
nstreams chan *Stream
errs chan error
channels map[uint64]*Stream
ch_lock sync.Mutex
}
func NewMultiplex(con io.ReadWriteCloser, initiator bool) *Multiplex {
mp := &Multiplex{
con: con,
initiator: initiator,
buf: bufio.NewReader(con),
channels: make(map[uint64]*Stream),
outchan: make(chan msg),
closed: make(chan struct{}),
nstreams: make(chan *Stream, 16),
errs: make(chan error),
}
go mp.handleOutgoing()
go mp.handleIncoming()
return mp
}
func (mp *Multiplex) Close() error {
if mp.IsClosed() {
return nil
}
close(mp.closed)
mp.ch_lock.Lock()
defer mp.ch_lock.Unlock()
for _, s := range mp.channels {
err := s.Close()
if err != nil {
return err
}
}
return nil
}
func (mp *Multiplex) IsClosed() bool {
select {
case <-mp.closed:
return true
default:
return false
}
}
func (mp *Multiplex) handleOutgoing() {
for {
select {
case msg, ok := <-mp.outchan:
if !ok {
return
}
buf := EncodeVarint(msg.header)
_, err := mp.con.Write(buf)
if err != nil {
msg.err <- err
continue
}
buf = EncodeVarint(uint64(len(msg.data)))
_, err = mp.con.Write(buf)
if err != nil {
msg.err <- err
continue
}
_, err = mp.con.Write(msg.data)
if err != nil {
msg.err <- err
continue
}
msg.err <- nil
case <-mp.closed:
return
}
}
}
func (mp *Multiplex) nextChanID() (out uint64) {
if mp.initiator {
out = mp.nextID + 1
} else {
out = mp.nextID
}
mp.nextID += 2
return
}
func (mp *Multiplex) NewStream() *Stream {
return mp.NewNamedStream("")
}
func (mp *Multiplex) NewNamedStream(name string) *Stream {
mp.ch_lock.Lock()
sid := mp.nextChanID()
header := (sid << 3) | NewStream
if name == "" {
name = fmt.Sprint(sid)
}
s := newStream(sid, name, true, mp.outchan)
mp.channels[sid] = s
mp.ch_lock.Unlock()
mp.outchan <- msg{
header: header,
data: []byte(name),
err: make(chan error, 1), //throw away error
}
return s
}
func (mp *Multiplex) sendErr(err error) {
select {
case mp.errs <- err:
case <-mp.closed:
}
}
func (mp *Multiplex) handleIncoming() {
defer mp.shutdown()
for {
ch, tag, err := mp.readNextHeader()
if err != nil {
mp.sendErr(err)
return
}
b, err := mp.readNext()
if err != nil {
mp.sendErr(err)
return
}
mp.ch_lock.Lock()
msch, ok := mp.channels[ch]
if !ok {
var name string
if tag == NewStream {
name = string(b)
}
msch = newStream(ch, name, false, mp.outchan)
mp.channels[ch] = msch
select {
case mp.nstreams <- msch:
case <-mp.closed:
return
}
if tag == NewStream {
mp.ch_lock.Unlock()
continue
}
}
mp.ch_lock.Unlock()
if tag == Close {
msch.Close()
mp.ch_lock.Lock()
delete(mp.channels, ch)
mp.ch_lock.Unlock()
continue
}
msch.receive(b)
}
}
func (mp *Multiplex) shutdown() {
mp.ch_lock.Lock()
defer mp.ch_lock.Unlock()
for _, s := range mp.channels {
s.Close()
}
}
func (mp *Multiplex) readNextHeader() (uint64, uint64, error) {
h, _, err := DecodeVarint(mp.buf)
if err != nil {
return 0, 0, err
}
// get channel ID
ch := h >> 3
rem := h & 7
return ch, rem, nil
}
func (mp *Multiplex) readNext() ([]byte, error) {
// get length
l, _, err := DecodeVarint(mp.buf)
if err != nil {
return nil, err
}
buf := make([]byte, l)
n, err := io.ReadFull(mp.buf, buf)
if err != nil {
return nil, err
}
if n != int(l) {
panic("NOT THE SAME")
}
return buf, nil
}
func EncodeVarint(x uint64) []byte {
var buf [10]byte
var n int
for n = 0; x > 127; n++ {
buf[n] = 0x80 | uint8(x&0x7F)
x >>= 7
}
buf[n] = uint8(x)
n++
return buf[0:n]
}
func DecodeVarint(r *bufio.Reader) (x uint64, n int, err error) {
// x, n already 0
for shift := uint(0); shift < 64; shift += 7 {
val, err := r.ReadByte()
if err != nil {
return 0, 0, err
}
b := uint64(val)
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n, nil
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0, errors.New("Too large of a number!")
}
package multiplex
import (
"fmt"
"io"
"net"
"testing"
rand "QmciEePSP8wpGYZp8fsPFi49Ya7xQMUFwFj2z5fDpfZnhC/randbo"
)
func TestBasicStreams(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
mes := []byte("Hello world")
go func() {
s, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
_, err = s.Write(mes)
if err != nil {
t.Fatal(err)
}
err = s.Close()
if err != nil {
t.Fatal(err)
}
}()
s := mpa.NewStream()
buf := make([]byte, len(mes))
n, err := s.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(mes) {
t.Fatal("read wrong amount")
}
if string(buf) != string(mes) {
t.Fatal("got bad data")
}
s.Close()
mpa.Close()
mpb.Close()
}
func TestEcho(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
mes := make([]byte, 40960)
rand.New().Read(mes)
go func() {
s, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer s.Close()
io.Copy(s, s)
}()
s := mpa.NewStream()
_, err := s.Write(mes)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, len(mes))
n, err := io.ReadFull(s, buf)
if err != nil {
t.Fatal(err)
}
if n != len(mes) {
t.Fatal("read wrong amount")
}
if err := arrComp(buf, mes); err != nil {
t.Fatal(err)
}
s.Close()
mpa.Close()
mpb.Close()
}
func arrComp(a, b []byte) error {
msg := ""
if len(a) != len(b) {
msg += fmt.Sprintf("arrays differ in length: %d %d\n", len(a), len(b))
}
for i := 0; i < len(a) && i < len(b); i++ {
if a[i] != b[i] {
msg += fmt.Sprintf("content differs at index %d [%d != %d]", i, a[i], b[i])
return fmt.Errorf(msg)
}
}
if len(msg) > 0 {
return fmt.Errorf(msg)
}
return nil
}
{
"name": "go-multiplex",
"author": "whyrusleeping",
"version": "1.0.0",
"gxDependencies": [
{
"name": "randbo",
"hash": "QmciEePSP8wpGYZp8fsPFi49Ya7xQMUFwFj2z5fDpfZnhC",
"version": "1.0.0"
}
],
"language": "go",
"gx": {
"dvcsimport": "github.com/whyrusleeping/go-multiplex"
}
}
\ No newline at end of file
A fast random number `io.Reader` implementation.
![randbo](https://raw.github.com/dustin/randbo/master/randbo.png)
> IN A WORLD where no integer sequence is certain ...
>
> ONE MAN must become statistically indistinguishable from noise
>
> THIS SUMMER, entropy has a new name: RANDBO
(thanks @snej)
{
"name": "randbo",
"author": "whyrusleeping",
"version": "1.0.0",
"language": "go",
"gx": {
"dvcsimport": "github.com/dustin/randbo"
}
}
\ No newline at end of file
package randbo
import (
"io"
"math/rand"
"time"
)
// Randbo creates a stream of non-crypto quality random bytes
type randbo struct {
rand.Source
}
// New creates a new random reader with a time source.
func New() io.Reader {
return NewFrom(rand.NewSource(time.Now().UnixNano()))
}
// NewFrom creates a new reader from your own rand.Source
func NewFrom(src rand.Source) io.Reader {
return &randbo{src}
}
// Read satisfies io.Reader
func (r *randbo) Read(p []byte) (n int, err error) {
todo := len(p)
offset := 0
for {
val := int64(r.Int63())
for i := 0; i < 8; i++ {
p[offset] = byte(val)
todo--
if todo == 0 {
return len(p), nil
}
offset++
val >>= 8
}
}
}
package randbo
import (
"crypto/rand"
"io"
"io/ioutil"
"testing"
)
func TestRandbo(t *testing.T) {
buf := make([]byte, 16)
n, err := New().Read(buf)
if err != nil {
t.Fatalf("Error reading: %v", err)
}
if n != len(buf) {
t.Fatalf("Short read: %v", n)
}
t.Logf("Read %x", buf)
}
const toCopy = 1024 * 1024
func BenchmarkRandbo(b *testing.B) {
b.SetBytes(toCopy)
r := New()
for i := 0; i < b.N; i++ {
io.CopyN(ioutil.Discard, r, toCopy)
}
}
func BenchmarkCrypto(b *testing.B) {
b.SetBytes(toCopy)
for i := 0; i < b.N; i++ {
io.CopyN(ioutil.Discard, rand.Reader, toCopy)
}
}
#Multistream-select router
This package implements a simple stream router for the multistream-select protocol.
The protocol is defined [here](https://github.com/jbenet/multistream).
Usage:
```go
package main
import (
"fmt"
ms "github.com/whyrusleeping/go-multistream"
"io"
"net"
)
func main() {
mux := ms.NewMultistreamMuxer()
mux.AddHandler("/cats", func(rwc io.ReadWriteCloser) error {
fmt.Fprintln(rwc, "HELLO I LIKE CATS")
return rwc.Close()
})
mux.AddHandler("/dogs", func(rwc io.ReadWriteCloser) error {
fmt.Fprintln(rwc, "HELLO I LIKE DOGS")
return rwc.Close()
})
list, err := net.Listen("tcp", ":8765")
if err != nil {
panic(err)
}
for {
con, err := list.Accept()
if err != nil {
panic(err)
}
go mux.Handle(con)
}
}
```
package multistream
import (
"errors"
"io"
)
var ErrNotSupported = errors.New("protocol not supported")
func SelectProtoOrFail(proto string, rwc io.ReadWriteCloser) error {
err := handshake(rwc)
if err != nil {
return err
}
return trySelect(proto, rwc)
}
func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (string, error) {
err := handshake(rwc)
if err != nil {
return "", err
}
for _, p := range protos {
err := trySelect(p, rwc)
switch err {
case nil:
return p, nil
case ErrNotSupported:
default:
return "", err
}
}
return "", ErrNotSupported
}
func handshake(rwc io.ReadWriteCloser) error {
tok, err := ReadNextToken(rwc)
if err != nil {
return err
}
if tok != ProtocolID {
return errors.New("received mismatch in protocol id")
}
err = delimWrite(rwc, []byte(ProtocolID))
if err != nil {
return err
}
return nil
}
func trySelect(proto string, rwc io.ReadWriteCloser) error {
err := delimWrite(rwc, []byte(proto))
if err != nil {
return err
}
tok, err := ReadNextToken(rwc)
if err != nil {
return err
}
switch tok {
case proto:
return nil
case "na":
return ErrNotSupported
default:
return errors.New("unrecognized response: " + tok)
}
}
package multistream
import (
"fmt"
"io"
"sync"
)
type Multistream interface {
io.ReadWriteCloser
Protocol() string
}
func NewMSSelect(c io.ReadWriteCloser, proto string) Multistream {
return NewMultistream(NewMultistream(c, ProtocolID), proto)
}
func NewMultistream(c io.ReadWriteCloser, proto string) Multistream {
return &lazyConn{
proto: proto,
con: c,
}
}
type lazyConn struct {
rhandshake bool // only accessed by 'Read' should not call read async
rhlock sync.Mutex
rhsync bool //protected by mutex
rerr error
whandshake bool
whlock sync.Mutex
whsync bool
werr error
proto string
con io.ReadWriteCloser
}
func (l *lazyConn) Protocol() string {
return l.proto
}
func (l *lazyConn) Read(b []byte) (int, error) {
if !l.rhandshake {
go l.writeHandshake()
err := l.readHandshake()
if err != nil {
return 0, err
}
l.rhandshake = true
}
if len(b) == 0 {
return 0, nil
}
return l.con.Read(b)
}
func (l *lazyConn) readHandshake() error {
l.rhlock.Lock()
defer l.rhlock.Unlock()
// if we've already done this, exit
if l.rhsync {
return l.rerr
}
l.rhsync = true
// read protocol
tok, err := ReadNextToken(l.con)
if err != nil {
l.rerr = err
return err
}
if tok != l.proto {
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, l.proto)
return l.rerr
}
return nil
}
func (l *lazyConn) writeHandshake() error {
l.whlock.Lock()
defer l.whlock.Unlock()
if l.whsync {
return l.werr
}
l.whsync = true
err := delimWrite(l.con, []byte(l.proto))
if err != nil {
l.werr = err
return err
}
return nil
}
func (l *lazyConn) Write(b []byte) (int, error) {
if !l.whandshake {
go l.readHandshake()
err := l.writeHandshake()
if err != nil {
return 0, err
}
l.whandshake = true
}
return l.con.Write(b)
}
func (l *lazyConn) Close() error {
return l.con.Close()
}
package multistream
import (
"bytes"
"encoding/binary"
"errors"
"io"
"sync"
)
var ErrTooLarge = errors.New("incoming message was too large")
const ProtocolID = "/multistream/1.0.0"
type HandlerFunc func(io.ReadWriteCloser) error
type MultistreamMuxer struct {
handlerlock sync.Mutex
handlers map[string]HandlerFunc
}
func NewMultistreamMuxer() *MultistreamMuxer {
return &MultistreamMuxer{handlers: make(map[string]HandlerFunc)}
}
func writeUvarint(w io.Writer, i uint64) error {
varintbuf := make([]byte, 32)
n := binary.PutUvarint(varintbuf, i)
_, err := w.Write(varintbuf[:n])
if err != nil {
return err
}
return nil
}
func delimWrite(w io.Writer, mes []byte) error {
err := writeUvarint(w, uint64(len(mes)+1))
if err != nil {
return err
}
_, err = w.Write(mes)
if err != nil {
return err
}
_, err = w.Write([]byte{'\n'})
if err != nil {
return err
}
return nil
}
func (msm *MultistreamMuxer) AddHandler(protocol string, handler HandlerFunc) {
msm.handlerlock.Lock()
msm.handlers[protocol] = handler
msm.handlerlock.Unlock()
}
func (msm *MultistreamMuxer) RemoveHandler(protocol string) {
msm.handlerlock.Lock()
delete(msm.handlers, protocol)
msm.handlerlock.Unlock()
}
func (msm *MultistreamMuxer) Protocols() []string {
var out []string
msm.handlerlock.Lock()
for k, _ := range msm.handlers {
out = append(out, k)
}
msm.handlerlock.Unlock()
return out
}
func (msm *MultistreamMuxer) Negotiate(rwc io.ReadWriteCloser) (string, HandlerFunc, error) {
// Send our protocol ID
err := delimWrite(rwc, []byte(ProtocolID))
if err != nil {
return "", nil, err
}
line, err := ReadNextToken(rwc)
if err != nil {
return "", nil, err
}
if line != ProtocolID {
rwc.Close()
return "", nil, errors.New("client connected with incorrect version")
}
loop:
for {
// Now read and respond to commands until they send a valid protocol id
tok, err := ReadNextToken(rwc)
if err != nil {
return "", nil, err
}
switch tok {
case "ls":
err := msm.Ls(rwc)
if err != nil {
return "", nil, err
}
default:
msm.handlerlock.Lock()
h, ok := msm.handlers[tok]
msm.handlerlock.Unlock()
if !ok {
err := delimWrite(rwc, []byte("na"))
if err != nil {
return "", nil, err
}
continue loop
}
err := delimWrite(rwc, []byte(tok))
if err != nil {
return "", nil, err
}
// hand off processing to the sub-protocol handler
return tok, h, nil
}
}
}
func (msm *MultistreamMuxer) Ls(rwc io.Writer) error {
buf := new(bytes.Buffer)
msm.handlerlock.Lock()
for proto, _ := range msm.handlers {
err := delimWrite(buf, []byte(proto))
if err != nil {
msm.handlerlock.Unlock()
return err
}
}
msm.handlerlock.Unlock()
err := delimWrite(rwc, buf.Bytes())
if err != nil {
return err
}
return nil
}
func (msm *MultistreamMuxer) Handle(rwc io.ReadWriteCloser) error {
_, h, err := msm.Negotiate(rwc)
if err != nil {
return err
}
return h(rwc)
}
func ReadNextToken(rw io.ReadWriter) (string, error) {
br := &byteReader{rw}
length, err := binary.ReadUvarint(br)
if err != nil {
return "", err
}
if length > 64*1024 {
err := delimWrite(rw, []byte("messages over 64k are not allowed"))
if err != nil {
return "", err
}
return "", ErrTooLarge
}
buf := make([]byte, length)
_, err = io.ReadFull(rw, buf)
if err != nil {
return "", err
}
if len(buf) == 0 || buf[length-1] != '\n' {
return "", errors.New("message did not have trailing newline")
}
// slice off the trailing newline
buf = buf[:length-1]
return string(buf), nil
}
// byteReader implements the ByteReader interface that ReadUVarint requires
type byteReader struct {
io.Reader
}
func (br *byteReader) ReadByte() (byte, error) {
var b [1]byte
_, err := br.Read(b[:])
if err != nil {
return 0, err
}
return b[0], nil
}
package multistream
import (
"crypto/rand"
"io"
"net"
"testing"
"time"
)
func TestProtocolNegotiation(t *testing.T) {
a, b := net.Pipe()
mux := NewMultistreamMuxer()
mux.AddHandler("/a", nil)
mux.AddHandler("/b", nil)
mux.AddHandler("/c", nil)
done := make(chan struct{})
go func() {
selected, _, err := mux.Negotiate(a)
if err != nil {
t.Fatal(err)
}
if selected != "/a" {
t.Fatal("incorrect protocol selected")
}
close(done)
}()
err := SelectProtoOrFail("/a", b)
if err != nil {
t.Fatal(err)
}
select {
case <-time.After(time.Second):
t.Fatal("protocol negotiation didnt complete")
case <-done:
}
verifyPipe(t, a, b)
}
func TestSelectOne(t *testing.T) {
a, b := net.Pipe()
mux := NewMultistreamMuxer()
mux.AddHandler("/a", nil)
mux.AddHandler("/b", nil)
mux.AddHandler("/c", nil)
done := make(chan struct{})
go func() {
selected, _, err := mux.Negotiate(a)
if err != nil {
t.Fatal(err)
}
if selected != "/c" {
t.Fatal("incorrect protocol selected")
}
close(done)
}()
sel, err := SelectOneOf([]string{"/d", "/e", "/c"}, b)
if err != nil {
t.Fatal(err)
}
if sel != "/c" {
t.Fatal("selected wrong protocol")
}
select {
case <-time.After(time.Second):
t.Fatal("protocol negotiation didnt complete")
case <-done:
}
verifyPipe(t, a, b)
}
func TestSelectOneAndWrite(t *testing.T) {
a, b := net.Pipe()
mux := NewMultistreamMuxer()
mux.AddHandler("/a", nil)
mux.AddHandler("/b", nil)
mux.AddHandler("/c", nil)
done := make(chan struct{})
go func() {
selected, _, err := mux.Negotiate(a)
if err != nil {
t.Fatal(err)
}
if selected != "/c" {
t.Fatal("incorrect protocol selected")
}
close(done)
}()
sel, err := SelectOneOf([]string{"/d", "/e", "/c"}, b)
if err != nil {
t.Fatal(err)
}
if sel != "/c" {
t.Fatal("selected wrong protocol")
}
select {
case <-time.After(time.Second):
t.Fatal("protocol negotiation didnt complete")
case <-done:
}
verifyPipe(t, a, b)
}
func TestLazyConns(t *testing.T) {
a, b := net.Pipe()
mux := NewMultistreamMuxer()
mux.AddHandler("/a", nil)
mux.AddHandler("/b", nil)
mux.AddHandler("/c", nil)
la := NewMSSelect(a, "/c")
lb := NewMSSelect(b, "/c")
verifyPipe(t, la, lb)
}
func TestLazyAndMux(t *testing.T) {
a, b := net.Pipe()
mux := NewMultistreamMuxer()
mux.AddHandler("/a", nil)
mux.AddHandler("/b", nil)
mux.AddHandler("/c", nil)
done := make(chan struct{})
go func() {
selected, _, err := mux.Negotiate(a)
if err != nil {
t.Fatal(err)
}
if selected != "/c" {
t.Fatal("incorrect protocol selected")
}
msg := make([]byte, 5)
_, err = a.Read(msg)
if err != nil {
t.Fatal(err)
}
close(done)
}()
lb := NewMSSelect(b, "/c")
// do a write to push the handshake through
_, err := lb.Write([]byte("hello"))
if err != nil {
t.Fatal(err)
}
select {
case <-time.After(time.Second):
t.Fatal("failed to complete in time")
case <-done:
}
verifyPipe(t, a, lb)
}
func TestLazyAndMuxWrite(t *testing.T) {
a, b := net.Pipe()
mux := NewMultistreamMuxer()
mux.AddHandler("/a", nil)
mux.AddHandler("/b", nil)
mux.AddHandler("/c", nil)
done := make(chan struct{})
go func() {
selected, _, err := mux.Negotiate(a)
if err != nil {
t.Fatal(err)
}
if selected != "/c" {
t.Fatal("incorrect protocol selected")
}
_, err = a.Write([]byte("hello"))
if err != nil {
t.Fatal(err)
}
close(done)
}()
lb := NewMSSelect(b, "/c")
// do a write to push the handshake through
msg := make([]byte, 5)
_, err := lb.Read(msg)
if err != nil {
t.Fatal(err)
}
if string(msg) != "hello" {
t.Fatal("wrong!")
}
select {
case <-time.After(time.Second):
t.Fatal("failed to complete in time")
case <-done:
}
verifyPipe(t, a, lb)
}
func verifyPipe(t *testing.T, a, b io.ReadWriter) {
mes := make([]byte, 1024)
rand.Read(mes)
go func() {
b.Write(mes)
a.Write(mes)
}()
buf := make([]byte, len(mes))
n, err := a.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(buf) {
t.Fatal("failed to read enough")
}
if string(buf) != string(mes) {
t.Fatal("somehow read wrong message")
}
n, err = b.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(buf) {
t.Fatal("failed to read enough")
}
if string(buf) != string(mes) {
t.Fatal("somehow read wrong message")
}
}
{
"name": "go-multistream",
"author": "whyrusleeping",
"version": "1.0.0",
"language": "go",
"gx": {
"dvcsimport": "github.com/whyrusleeping/go-multistream"
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment