Commit d91b419e authored by Jeromy's avatar Jeromy
Browse files

WIP

parent a40ef343
package yamux
import (
"testing"
)
func BenchmarkPing(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
for i := 0; i < b.N; i++ {
rtt, err := client.Ping()
if err != nil {
b.Fatalf("err: %v", err)
}
if rtt == 0 {
b.Fatalf("bad: %v", rtt)
}
}
}
func BenchmarkAccept(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
go func() {
for i := 0; i < b.N; i++ {
stream, err := server.AcceptStream()
if err != nil {
return
}
stream.Close()
}
}()
for i := 0; i < b.N; i++ {
stream, err := client.Open()
if err != nil {
b.Fatalf("err: %v", err)
}
stream.Close()
}
}
func BenchmarkSendRecv(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
sendBuf := make([]byte, 512)
recvBuf := make([]byte, 512)
doneCh := make(chan struct{})
go func() {
stream, err := server.AcceptStream()
if err != nil {
return
}
defer stream.Close()
for i := 0; i < b.N; i++ {
if _, err := stream.Read(recvBuf); err != nil {
b.Fatalf("err: %v", err)
}
}
close(doneCh)
}()
stream, err := client.Open()
if err != nil {
b.Fatalf("err: %v", err)
}
defer stream.Close()
for i := 0; i < b.N; i++ {
if _, err := stream.Write(sendBuf); err != nil {
b.Fatalf("err: %v", err)
}
}
<-doneCh
}
package yamux
import (
"encoding/binary"
"fmt"
)
var (
// ErrInvalidVersion means we received a frame with an
// invalid version
ErrInvalidVersion = fmt.Errorf("invalid protocol version")
// ErrInvalidMsgType means we received a frame with an
// invalid message type
ErrInvalidMsgType = fmt.Errorf("invalid msg type")
// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = fmt.Errorf("session shutdown")
// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
ErrStreamsExhausted = fmt.Errorf("streams exhausted")
// ErrDuplicateStream is used if a duplicate stream is
// opened inbound
ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")
// ErrReceiveWindowExceeded indicates the window was exceeded
ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
// ErrTimeout is used when we reach an IO deadline
ErrTimeout = fmt.Errorf("i/o deadline reached")
// ErrStreamClosed is returned when using a closed stream
ErrStreamClosed = fmt.Errorf("stream closed")
// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = fmt.Errorf("unexpected flag")
// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
// ErrConnectionReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
ErrConnectionReset = fmt.Errorf("connection reset")
)
const (
// protoVersion is the only version we support
protoVersion uint8 = 0
)
const (
// Data is used for data frames. They are followed
// by length bytes worth of payload.
typeData uint8 = iota
// WindowUpdate is used to change the window of
// a given stream. The length indicates the delta
// update to the window.
typeWindowUpdate
// Ping is sent as a keep-alive or to measure
// the RTT. The StreamID and Length value are echoed
// back in the response.
typePing
// GoAway is sent to terminate a session. The StreamID
// should be 0 and the length is an error code.
typeGoAway
)
const (
// SYN is sent to signal a new stream. May
// be sent with a data payload
flagSYN uint16 = 1 << iota
// ACK is sent to acknowledge a new stream. May
// be sent with a data payload
flagACK
// FIN is sent to half-close the given stream.
// May be sent with a data payload.
flagFIN
// RST is used to hard close a given stream.
flagRST
)
const (
// initialStreamWindow is the initial stream window size
initialStreamWindow uint32 = 256 * 1024
)
const (
// goAwayNormal is sent on a normal termination
goAwayNormal uint32 = iota
// goAwayProtoErr sent on a protocol error
goAwayProtoErr
// goAwayInternalErr sent on an internal error
goAwayInternalErr
)
const (
sizeOfVersion = 1
sizeOfType = 1
sizeOfFlags = 2
sizeOfStreamID = 4
sizeOfLength = 4
headerSize = sizeOfVersion + sizeOfType + sizeOfFlags +
sizeOfStreamID + sizeOfLength
)
type header []byte
func (h header) Version() uint8 {
return h[0]
}
func (h header) MsgType() uint8 {
return h[1]
}
func (h header) Flags() uint16 {
return binary.BigEndian.Uint16(h[2:4])
}
func (h header) StreamID() uint32 {
return binary.BigEndian.Uint32(h[4:8])
}
func (h header) Length() uint32 {
return binary.BigEndian.Uint32(h[8:12])
}
func (h header) String() string {
return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d",
h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
}
func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) {
h[0] = protoVersion
h[1] = msgType
binary.BigEndian.PutUint16(h[2:4], flags)
binary.BigEndian.PutUint32(h[4:8], streamID)
binary.BigEndian.PutUint32(h[8:12], length)
}
package yamux
import (
"testing"
)
func TestConst(t *testing.T) {
if protoVersion != 0 {
t.Fatalf("bad: %v", protoVersion)
}
if typeData != 0 {
t.Fatalf("bad: %v", typeData)
}
if typeWindowUpdate != 1 {
t.Fatalf("bad: %v", typeWindowUpdate)
}
if typePing != 2 {
t.Fatalf("bad: %v", typePing)
}
if typeGoAway != 3 {
t.Fatalf("bad: %v", typeGoAway)
}
if flagSYN != 1 {
t.Fatalf("bad: %v", flagSYN)
}
if flagACK != 2 {
t.Fatalf("bad: %v", flagACK)
}
if flagFIN != 4 {
t.Fatalf("bad: %v", flagFIN)
}
if flagRST != 8 {
t.Fatalf("bad: %v", flagRST)
}
if goAwayNormal != 0 {
t.Fatalf("bad: %v", goAwayNormal)
}
if goAwayProtoErr != 1 {
t.Fatalf("bad: %v", goAwayProtoErr)
}
if goAwayInternalErr != 2 {
t.Fatalf("bad: %v", goAwayInternalErr)
}
if headerSize != 12 {
t.Fatalf("bad header size")
}
}
func TestEncodeDecode(t *testing.T) {
hdr := header(make([]byte, headerSize))
hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321)
if hdr.Version() != protoVersion {
t.Fatalf("bad: %v", hdr)
}
if hdr.MsgType() != typeWindowUpdate {
t.Fatalf("bad: %v", hdr)
}
if hdr.Flags() != flagACK|flagRST {
t.Fatalf("bad: %v", hdr)
}
if hdr.StreamID() != 1234 {
t.Fatalf("bad: %v", hdr)
}
if hdr.Length() != 4321 {
t.Fatalf("bad: %v", hdr)
}
}
package yamux
import (
"fmt"
"io"
"os"
"time"
)
// Config is used to tune the Yamux session
type Config struct {
// AcceptBacklog is used to limit how many streams may be
// waiting an accept.
AcceptBacklog int
// EnableKeepalive is used to do a period keep alive
// messages using a ping.
EnableKeepAlive bool
// KeepAliveInterval is how often to perform the keep alive
KeepAliveInterval time.Duration
// MaxStreamWindowSize is used to control the maximum
// window size that we allow for a stream.
MaxStreamWindowSize uint32
// LogOutput is used to control the log destination
LogOutput io.Writer
}
// DefaultConfig is used to return a default configuration
func DefaultConfig() *Config {
return &Config{
AcceptBacklog: 256,
EnableKeepAlive: true,
KeepAliveInterval: 30 * time.Second,
MaxStreamWindowSize: initialStreamWindow,
LogOutput: os.Stderr,
}
}
// VerifyConfig is used to verify the sanity of configuration
func VerifyConfig(config *Config) error {
if config.AcceptBacklog <= 0 {
return fmt.Errorf("backlog must be positive")
}
if config.KeepAliveInterval == 0 {
return fmt.Errorf("keep-alive interval must be positive")
}
if config.MaxStreamWindowSize < initialStreamWindow {
return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow)
}
return nil
}
// Server is used to initialize a new server-side connection.
// There must be at most one server-side connection. If a nil config is
// provided, the DefaultConfiguration will be used.
func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
if config == nil {
config = DefaultConfig()
}
if err := VerifyConfig(config); err != nil {
return nil, err
}
return newSession(config, conn, false), nil
}
// Client is used to initialize a new client-side connection.
// There must be at most one client-side connection.
func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
if config == nil {
config = DefaultConfig()
}
if err := VerifyConfig(config); err != nil {
return nil, err
}
return newSession(config, conn, true), nil
}
{
"name": "yamux",
"author": "whyrusleeping",
"version": "1.0.0",
"language": "go",
"gx": {
"dvcsimport": "github.com/hashicorp/yamux"
}
}
\ No newline at end of file
package yamux
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"log"
"math"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
// remoteGoAway indicates the remote side does
// not want futher connections. Must be first for alignment.
remoteGoAway int32
// localGoAway indicates that we should stop
// accepting futher connections. Must be first for alignment.
localGoAway int32
// nextStreamID is the next stream we should
// send. This depends if we are a client/server.
nextStreamID uint32
// config holds our configuration
config *Config
// logger is used for our logs
logger *log.Logger
// conn is the underlying connection
conn io.ReadWriteCloser
// bufRead is a buffered reader
bufRead *bufio.Reader
// pings is used to track inflight pings
pings map[uint32]chan struct{}
pingID uint32
pingLock sync.Mutex
// streams maps a stream id to a stream
streams map[uint32]*Stream
streamLock sync.Mutex
// synCh acts like a semaphore. It is sized to the AcceptBacklog which
// is assumed to be symmetric between the client and server. This allows
// the client to avoid exceeding the backlog and instead blocks the open.
synCh chan struct{}
// acceptCh is used to pass ready streams to the client
acceptCh chan *Stream
// sendCh is used to mark a stream as ready to send,
// or to send a header out directly.
sendCh chan sendReady
// recvDoneCh is closed when recv() exits to avoid a race
// between stream registration and stream shutdown
recvDoneCh chan struct{}
// shutdown is used to safely close a session
shutdown bool
shutdownErr error
shutdownCh chan struct{}
shutdownLock sync.Mutex
}
// sendReady is used to either mark a stream as ready
// or to directly send a header
type sendReady struct {
Hdr []byte
Body io.Reader
Err chan error
}
// newSession is used to construct a new session
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
s := &Session{
config: config,
logger: log.New(config.LogOutput, "", log.LstdFlags),
conn: conn,
bufRead: bufio.NewReader(conn),
pings: make(map[uint32]chan struct{}),
streams: make(map[uint32]*Stream),
synCh: make(chan struct{}, config.AcceptBacklog),
acceptCh: make(chan *Stream, config.AcceptBacklog),
sendCh: make(chan sendReady, 64),
recvDoneCh: make(chan struct{}),
shutdownCh: make(chan struct{}),
}
if client {
s.nextStreamID = 1
} else {
s.nextStreamID = 2
}
go s.recv()
go s.send()
if config.EnableKeepAlive {
go s.keepalive()
}
return s
}
// IsClosed does a safe check to see if we have shutdown
func (s *Session) IsClosed() bool {
select {
case <-s.shutdownCh:
return true
default:
return false
}
}
// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
s.streamLock.Lock()
num := len(s.streams)
s.streamLock.Unlock()
return num
}
// Open is used to create a new stream as a net.Conn
func (s *Session) Open() (net.Conn, error) {
return s.OpenStream()
}
// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
if s.IsClosed() {
return nil, ErrSessionShutdown
}
if atomic.LoadInt32(&s.remoteGoAway) == 1 {
return nil, ErrRemoteGoAway
}
// Block if we have too many inflight SYNs
select {
case s.synCh <- struct{}{}:
case <-s.shutdownCh:
return nil, ErrSessionShutdown
}
GET_ID:
// Get and ID, and check for stream exhaustion
id := atomic.LoadUint32(&s.nextStreamID)
if id >= math.MaxUint32-1 {
return nil, ErrStreamsExhausted
}
if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
goto GET_ID
}
// Register the stream
stream := newStream(s, id, streamInit)
s.streamLock.Lock()
s.streams[id] = stream
s.streamLock.Unlock()
// Send the window update to create
if err := stream.sendWindowUpdate(); err != nil {
return nil, err
}
return stream, nil
}
// Accept is used to block until the next available stream
// is ready to be accepted.
func (s *Session) Accept() (net.Conn, error) {
return s.AcceptStream()
}
// AcceptStream is used to block until the next available stream
// is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) {
select {
case stream := <-s.acceptCh:
if err := stream.sendWindowUpdate(); err != nil {
return nil, err
}
return stream, nil
case <-s.shutdownCh:
return nil, s.shutdownErr
}
}
// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection.
func (s *Session) Close() error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()
if s.shutdown {
return nil
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = ErrSessionShutdown
}
close(s.shutdownCh)
s.conn.Close()
<-s.recvDoneCh
s.streamLock.Lock()
defer s.streamLock.Unlock()
for _, stream := range s.streams {
stream.forceClose()
}
return nil
}
// exitErr is used to handle an error that is causing the
// session to terminate.
func (s *Session) exitErr(err error) {
s.shutdownLock.Lock()
if s.shutdownErr == nil {
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.Close()
}
// GoAway can be used to prevent accepting further
// connections. It does not close the underlying conn.
func (s *Session) GoAway() error {
return s.waitForSend(s.goAway(goAwayNormal), nil)
}
// goAway is used to send a goAway message
func (s *Session) goAway(reason uint32) header {
atomic.SwapInt32(&s.localGoAway, 1)
hdr := header(make([]byte, headerSize))
hdr.encode(typeGoAway, 0, 0, reason)
return hdr
}
// Ping is used to measure the RTT response time
func (s *Session) Ping() (time.Duration, error) {
// Get a channel for the ping
ch := make(chan struct{})
// Get a new ping id, mark as pending
s.pingLock.Lock()
id := s.pingID
s.pingID++
s.pings[id] = ch
s.pingLock.Unlock()
// Send the ping request
hdr := header(make([]byte, headerSize))
hdr.encode(typePing, flagSYN, 0, id)
if err := s.waitForSend(hdr, nil); err != nil {
return 0, err
}
// Wait for a response
start := time.Now()
select {
case <-ch:
case <-s.shutdownCh:
return 0, ErrSessionShutdown
}
// Compute the RTT
return time.Now().Sub(start), nil
}
// keepalive is a long running goroutine that periodically does
// a ping to keep the connection alive.
func (s *Session) keepalive() {
for {
select {
case <-time.After(s.config.KeepAliveInterval):
s.Ping()
case <-s.shutdownCh:
return
}
}
}
// waitForSendErr waits to send a header, checking for a potential shutdown
func (s *Session) waitForSend(hdr header, body io.Reader) error {
errCh := make(chan error, 1)
return s.waitForSendErr(hdr, body, errCh)
}
// waitForSendErr waits to send a header, checking for a potential shutdown
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
case s.sendCh <- ready:
case <-s.shutdownCh:
return ErrSessionShutdown
}
select {
case err := <-errCh:
return err
case <-s.shutdownCh:
return ErrSessionShutdown
}
}
// sendNoWait does a send without waiting
func (s *Session) sendNoWait(hdr header) error {
select {
case s.sendCh <- sendReady{Hdr: hdr}:
return nil
case <-s.shutdownCh:
return ErrSessionShutdown
}
}
// send is a long running goroutine that sends data
func (s *Session) send() {
for {
select {
case ready := <-s.sendCh:
// Send a header if ready
if ready.Hdr != nil {
sent := 0
for sent < len(ready.Hdr) {
n, err := s.conn.Write(ready.Hdr[sent:])
if err != nil {
s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
asyncSendErr(ready.Err, err)
s.exitErr(err)
return
}
sent += n
}
}
// Send data from a body if given
if ready.Body != nil {
_, err := io.Copy(s.conn, ready.Body)
if err != nil {
s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
asyncSendErr(ready.Err, err)
s.exitErr(err)
return
}
}
// No error, successful send
asyncSendErr(ready.Err, nil)
case <-s.shutdownCh:
return
}
}
}
// recv is a long running goroutine that accepts new data
func (s *Session) recv() {
if err := s.recvLoop(); err != nil {
s.exitErr(err)
}
}
// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
}
return err
}
// Verify the version
if hdr.Version() != protoVersion {
s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
return ErrInvalidVersion
}
// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
return ErrInvalidMsgType
}
// Invoke the handler
if err := handler(hdr); err != nil {
return err
}
}
}
// handleStreamMessage handles either a data or window update frame
func (s *Session) handleStreamMessage(hdr header) error {
// Check for a new stream creation
id := hdr.StreamID()
flags := hdr.Flags()
if flags&flagSYN == flagSYN {
if err := s.incomingStream(id); err != nil {
return err
}
}
// Get the stream
s.streamLock.Lock()
stream := s.streams[id]
s.streamLock.Unlock()
// If we do not have a stream, likely we sent a RST
if stream == nil {
// Drain any data on the wire
if hdr.MsgType() == typeData && hdr.Length() > 0 {
s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
return nil
}
} else {
s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
}
return nil
}
// Check if this is a window update
if hdr.MsgType() == typeWindowUpdate {
if err := stream.incrSendWindow(hdr, flags); err != nil {
s.sendNoWait(s.goAway(goAwayProtoErr))
return err
}
return nil
}
// Read the new data
if err := stream.readData(hdr, flags, s.bufRead); err != nil {
s.sendNoWait(s.goAway(goAwayProtoErr))
return err
}
return nil
}
// handlePing is invokde for a typePing frame
func (s *Session) handlePing(hdr header) error {
flags := hdr.Flags()
pingID := hdr.Length()
// Check if this is a query, respond back
if flags&flagSYN == flagSYN {
hdr := header(make([]byte, headerSize))
hdr.encode(typePing, flagACK, 0, pingID)
s.sendNoWait(hdr)
return nil
}
// Handle a response
s.pingLock.Lock()
ch := s.pings[pingID]
if ch != nil {
delete(s.pings, pingID)
close(ch)
}
s.pingLock.Unlock()
return nil
}
// handleGoAway is invokde for a typeGoAway frame
func (s *Session) handleGoAway(hdr header) error {
code := hdr.Length()
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
return fmt.Errorf("yamux protocol error")
case goAwayInternalErr:
s.logger.Printf("[ERR] yamux: received internal error go away")
return fmt.Errorf("remote yamux internal error")
default:
s.logger.Printf("[ERR] yamux: received unexpected go away")
return fmt.Errorf("unexpected go away received")
}
return nil
}
// incomingStream is used to create a new incoming stream
func (s *Session) incomingStream(id uint32) error {
// Reject immediately if we are doing a go away
if atomic.LoadInt32(&s.localGoAway) == 1 {
hdr := header(make([]byte, headerSize))
hdr.encode(typeWindowUpdate, flagRST, id, 0)
return s.sendNoWait(hdr)
}
// Allocate a new stream
stream := newStream(s, id, streamSYNReceived)
s.streamLock.Lock()
defer s.streamLock.Unlock()
// Check if stream already exists
if _, ok := s.streams[id]; ok {
s.logger.Printf("[ERR] yamux: duplicate stream declared")
s.sendNoWait(s.goAway(goAwayProtoErr))
return ErrDuplicateStream
}
// Register the stream
s.streams[id] = stream
// Check if we've exceeded the backlog
select {
case s.acceptCh <- stream:
return nil
default:
// Backlog exceeded! RST the stream
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
delete(s.streams, id)
stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
return s.sendNoWait(stream.sendHdr)
}
}
// closeStream is used to close a stream once both sides have
// issued a close.
func (s *Session) closeStream(id uint32) {
s.streamLock.Lock()
delete(s.streams, id)
s.streamLock.Unlock()
}
// establishStream is used to mark a stream that was in the
// SYN Sent state as established.
func (s *Session) establishStream() {
select {
case <-s.synCh:
default:
panic("established stream without inflight syn")
}
}
package yamux
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"runtime"
"sync"
"testing"
"time"
)
type pipeConn struct {
reader *io.PipeReader
writer *io.PipeWriter
}
func (p *pipeConn) Read(b []byte) (int, error) {
return p.reader.Read(b)
}
func (p *pipeConn) Write(b []byte) (int, error) {
return p.writer.Write(b)
}
func (p *pipeConn) Close() error {
p.reader.Close()
return p.writer.Close()
}
func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
read1, write1 := io.Pipe()
read2, write2 := io.Pipe()
return &pipeConn{read1, write2}, &pipeConn{read2, write1}
}
func testClientServer() (*Session, *Session) {
conf := DefaultConfig()
conf.AcceptBacklog = 64
conf.KeepAliveInterval = 100 * time.Millisecond
return testClientServerConfig(conf)
}
func testClientServerConfig(conf *Config) (*Session, *Session) {
conn1, conn2 := testConn()
client, _ := Client(conn1, conf)
server, _ := Server(conn2, conf)
return client, server
}
func TestPing(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
rtt, err := client.Ping()
if err != nil {
t.Fatalf("err: %v", err)
}
if rtt == 0 {
t.Fatalf("bad: %v", rtt)
}
rtt, err = server.Ping()
if err != nil {
t.Fatalf("err: %v", err)
}
if rtt == 0 {
t.Fatalf("bad: %v", rtt)
}
}
func TestAccept(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
if client.NumStreams() != 0 {
t.Fatalf("bad")
}
if server.NumStreams() != 0 {
t.Fatalf("bad")
}
wg := &sync.WaitGroup{}
wg.Add(4)
go func() {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if id := stream.StreamID(); id != 1 {
t.Fatalf("bad: %v", id)
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
go func() {
defer wg.Done()
stream, err := client.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if id := stream.StreamID(); id != 2 {
t.Fatalf("bad: %v", id)
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
go func() {
defer wg.Done()
stream, err := server.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if id := stream.StreamID(); id != 2 {
t.Fatalf("bad: %v", id)
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
go func() {
defer wg.Done()
stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if id := stream.StreamID(); id != 1 {
t.Fatalf("bad: %v", id)
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
doneCh := make(chan struct{})
go func() {
wg.Wait()
close(doneCh)
}()
select {
case <-doneCh:
case <-time.After(time.Second):
panic("timeout")
}
}
func TestSendData_Small(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if server.NumStreams() != 1 {
t.Fatalf("bad")
}
buf := make([]byte, 4)
for i := 0; i < 1000; i++ {
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4 {
t.Fatalf("short read: %d", n)
}
if string(buf) != "test" {
t.Fatalf("bad: %s", buf)
}
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
go func() {
defer wg.Done()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
if client.NumStreams() != 1 {
t.Fatalf("bad")
}
for i := 0; i < 1000; i++ {
n, err := stream.Write([]byte("test"))
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4 {
t.Fatalf("short write %d", n)
}
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
doneCh := make(chan struct{})
go func() {
wg.Wait()
close(doneCh)
}()
select {
case <-doneCh:
case <-time.After(time.Second):
panic("timeout")
}
if client.NumStreams() != 0 {
t.Fatalf("bad")
}
if server.NumStreams() != 0 {
t.Fatalf("bad")
}
}
func TestSendData_Large(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
data := make([]byte, 512*1024)
for idx := range data {
data[idx] = byte(idx % 256)
}
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
buf := make([]byte, 4*1024)
for i := 0; i < 128; i++ {
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4*1024 {
t.Fatalf("short read: %d", n)
}
for idx := range buf {
if buf[idx] != byte(idx%256) {
t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
}
}
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
go func() {
defer wg.Done()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
n, err := stream.Write(data)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != len(data) {
t.Fatalf("short write %d", n)
}
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}()
doneCh := make(chan struct{})
go func() {
wg.Wait()
close(doneCh)
}()
select {
case <-doneCh:
case <-time.After(time.Second):
panic("timeout")
}
}
func TestGoAway(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
if err := server.GoAway(); err != nil {
t.Fatalf("err: %v", err)
}
_, err := client.Open()
if err != ErrRemoteGoAway {
t.Fatalf("err: %v", err)
}
}
func TestManyStreams(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
wg := &sync.WaitGroup{}
acceptor := func(i int) {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
buf := make([]byte, 512)
for {
n, err := stream.Read(buf)
if err == io.EOF {
return
}
if err != nil {
t.Fatalf("err: %v", err)
}
if n == 0 {
t.Fatalf("err: %v", err)
}
}
}
sender := func(i int) {
defer wg.Done()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
msg := fmt.Sprintf("%08d", i)
for i := 0; i < 1000; i++ {
n, err := stream.Write([]byte(msg))
if err != nil {
t.Fatalf("err: %v", err)
}
if n != len(msg) {
t.Fatalf("short write %d", n)
}
}
}
for i := 0; i < 50; i++ {
wg.Add(2)
go acceptor(i)
go sender(i)
}
wg.Wait()
}
func TestManyStreams_PingPong(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
wg := &sync.WaitGroup{}
ping := []byte("ping")
pong := []byte("pong")
acceptor := func(i int) {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
buf := make([]byte, 4)
for {
n, err := stream.Read(buf)
if err == io.EOF {
return
}
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4 {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, ping) {
t.Fatalf("bad: %s", buf)
}
n, err = stream.Write(pong)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4 {
t.Fatalf("err: %v", err)
}
}
}
sender := func(i int) {
defer wg.Done()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
buf := make([]byte, 4)
for i := 0; i < 1000; i++ {
n, err := stream.Write(ping)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4 {
t.Fatalf("short write %d", n)
}
n, err = stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4 {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, pong) {
t.Fatalf("bad: %s", buf)
}
}
}
for i := 0; i < 50; i++ {
wg.Add(2)
go acceptor(i)
go sender(i)
}
wg.Wait()
}
func TestHalfClose(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
if _, err := stream.Write([]byte("a")); err != nil {
t.Fatalf("err: %v", err)
}
stream2, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
stream2.Close() // Half close
buf := make([]byte, 4)
n, err := stream2.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 1 {
t.Fatalf("bad: %v", n)
}
// Send more
if _, err := stream.Write([]byte("bcd")); err != nil {
t.Fatalf("err: %v", err)
}
stream.Close()
// Read after close
n, err = stream2.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 3 {
t.Fatalf("bad: %v", n)
}
// EOF after close
n, err = stream2.Read(buf)
if err != io.EOF {
t.Fatalf("err: %v", err)
}
if n != 0 {
t.Fatalf("bad: %v", n)
}
}
func TestReadDeadline(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
stream2, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream2.Close()
if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
t.Fatalf("err: %v", err)
}
buf := make([]byte, 4)
if _, err := stream.Read(buf); err != ErrTimeout {
t.Fatalf("err: %v", err)
}
}
func TestWriteDeadline(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
stream2, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream2.Close()
if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
t.Fatalf("err: %v", err)
}
buf := make([]byte, 512)
for i := 0; i < int(initialStreamWindow); i++ {
_, err := stream.Write(buf)
if err != nil && err == ErrTimeout {
return
} else if err != nil {
t.Fatalf("err: %v", err)
}
}
t.Fatalf("Expected timeout")
}
func TestBacklogExceeded(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
// Fill the backlog
max := client.config.AcceptBacklog
for i := 0; i < max; i++ {
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
if _, err := stream.Write([]byte("foo")); err != nil {
t.Fatalf("err: %v", err)
}
}
// Attempt to open a new stream
errCh := make(chan error, 1)
go func() {
_, err := client.Open()
errCh <- err
}()
// Shutdown the server
go func() {
time.Sleep(10 * time.Millisecond)
server.Close()
}()
select {
case err := <-errCh:
if err == nil {
t.Fatalf("open should fail")
}
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestKeepAlive(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
time.Sleep(200 * time.Millisecond)
// Ping value should increase
client.pingLock.Lock()
defer client.pingLock.Unlock()
if client.pingID == 0 {
t.Fatalf("should ping")
}
server.pingLock.Lock()
defer server.pingLock.Unlock()
if server.pingID == 0 {
t.Fatalf("should ping")
}
}
func TestLargeWindow(t *testing.T) {
conf := DefaultConfig()
conf.MaxStreamWindowSize *= 2
client, server := testClientServerConfig(conf)
defer client.Close()
defer server.Close()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
stream2, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream2.Close()
stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
buf := make([]byte, conf.MaxStreamWindowSize)
n, err := stream.Write(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != len(buf) {
t.Fatalf("short write: %d", n)
}
}
type UnlimitedReader struct{}
func (u *UnlimitedReader) Read(p []byte) (int, error) {
runtime.Gosched()
return len(p), nil
}
func TestSendData_VeryLarge(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
var n int64 = 1 * 1024 * 1024 * 1024
var workers int = 16
wg := &sync.WaitGroup{}
wg.Add(workers * 2)
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
buf := make([]byte, 4)
_, err = stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
t.Fatalf("bad header")
}
recv, err := io.Copy(ioutil.Discard, stream)
if err != nil {
t.Fatalf("err: %v", err)
}
if recv != n {
t.Fatalf("bad: %v", recv)
}
}()
}
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
_, err = stream.Write([]byte{0, 1, 2, 3})
if err != nil {
t.Fatalf("err: %v", err)
}
unlimited := &UnlimitedReader{}
sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
if err != nil {
t.Fatalf("err: %v", err)
}
if sent != n {
t.Fatalf("bad: %v", sent)
}
}()
}
doneCh := make(chan struct{})
go func() {
wg.Wait()
close(doneCh)
}()
select {
case <-doneCh:
case <-time.After(20 * time.Second):
panic("timeout")
}
}
func TestBacklogExceeded_Accept(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
max := 5 * client.config.AcceptBacklog
go func() {
for i := 0; i < max; i++ {
stream, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
}
}()
// Fill the backlog
for i := 0; i < max; i++ {
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
if _, err := stream.Write([]byte("foo")); err != nil {
t.Fatalf("err: %v", err)
}
}
}
# Specification
We use this document to detail the internal specification of Yamux.
This is used both as a guide for implementing Yamux, but also for
alternative interoperable libraries to be built.
# Framing
Yamux uses a streaming connection underneath, but imposes a message
framing so that it can be shared between many logical streams. Each
frame contains a header like:
* Version (8 bits)
* Type (8 bits)
* Flags (16 bits)
* StreamID (32 bits)
* Length (32 bits)
This means that each header has a 12 byte overhead.
All fields are encoded in network order (big endian).
Each field is described below:
## Version Field
The version field is used for future backwards compatibily. At the
current time, the field is always set to 0, to indicate the initial
version.
## Type Field
The type field is used to switch the frame message type. The following
message types are supported:
* 0x0 Data - Used to transmit data. May transmit zero length payloads
depending on the flags.
* 0x1 Window Update - Used to updated the senders receive window size.
This is used to implement per-session flow control.
* 0x2 Ping - Used to measure RTT. It can also be used to heart-beat
and do keep-alives over TCP.
* 0x3 Go Away - Used to close a session.
## Flag Field
The flags field is used to provide additional information related
to the message type. The following flags are supported:
* 0x1 SYN - Signals the start of a new stream. May be sent with a data or
window update message. Also sent with a ping to indicate outbound.
* 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data
or window update message. Also sent with a ping to indicate response.
* 0x4 FIN - Performs a half-close of a stream. May be sent with a data
message or window update.
* 0x8 RST - Reset a stream immediately. May be sent with a data or
window update message.
## StreamID Field
The StreamID field is used to identify the logical stream the frame
is addressing. The client side should use odd ID's, and the server even.
This prevents any collisions. Additionally, the 0 ID is reserved to represent
the session.
Both Ping and Go Away messages should always use the 0 StreamID.
## Length Field
The meaning of the length field depends on the message type:
* Data - provides the length of bytes following the header
* Window update - provides a delta update to the window size
* Ping - Contains an opaque value, echoed back
* Go Away - Contains an error code
# Message Flow
There is no explicit connection setup, as Yamux relies on an underlying
transport to be provided. However, there is a distinction between client
and server side of the connection.
## Opening a stream
To open a stream, an initial data or window update frame is sent
with a new StreamID. The SYN flag should be set to signal a new stream.
The receiver must then reply with either a data or window update frame
with the StreamID along with the ACK flag to accept the stream or with
the RST flag to reject the stream.
Because we are relying on the reliable stream underneath, a connection
can begin sending data once the SYN flag is sent. The corresponding
ACK does not need to be received. This is particularly well suited
for an RPC system where a client wants to open a stream and immediately
fire a request without wiating for the RTT of the ACK.
This does introduce the possibility of a connection being rejected
after data has been sent already. This is a slight semantic difference
from TCP, where the conection cannot be refused after it is opened.
Clients should be prepared to handle this by checking for an error
that indicates a RST was received.
## Closing a stream
To close a stream, either side sends a data or window update frame
along with the FIN flag. This does a half-close indicating the sender
will send no further data.
Once both sides have closed the connection, the stream is closed.
Alternatively, if an error occurs, the RST flag can be used to
hard close a stream immediately.
## Flow Control
When Yamux is initially starts each stream with a 256KB window size.
There is no window size for the session.
To prevent the streams from stalling, window update frames should be
sent regularly. Yamux can be configured to provide a larger limit for
windows sizes. Both sides assume the initial 256KB window, but can
immediately send a window update as part of the SYN/ACK indicating a
larger window.
Both sides should track the number of bytes sent in Data frames
only, as only they are tracked as part of the window size.
## Session termination
When a session is being terminated, the Go Away message should
be sent. The Length should be set to one of the following to
provide an error code:
* 0x0 Normal termination
* 0x1 Protocol error
* 0x2 Internal error
package yamux
import (
"bytes"
"io"
"sync"
"sync/atomic"
"time"
)
type streamState int
const (
streamInit streamState = iota
streamSYNSent
streamSYNReceived
streamEstablished
streamLocalClose
streamRemoteClose
streamClosed
streamReset
)
// Stream is used to represent a logical stream
// within a session.
type Stream struct {
recvWindow uint32
sendWindow uint32
id uint32
session *Session
state streamState
stateLock sync.Mutex
recvBuf bytes.Buffer
recvLock sync.Mutex
controlHdr header
controlErr chan error
controlHdrLock sync.Mutex
sendHdr header
sendErr chan error
sendLock sync.Mutex
recvNotifyCh chan struct{}
sendNotifyCh chan struct{}
readDeadline time.Time
writeDeadline time.Time
}
// newStream is used to construct a new stream within
// a given session for an ID
func newStream(session *Session, id uint32, state streamState) *Stream {
s := &Stream{
id: id,
session: session,
state: state,
controlHdr: header(make([]byte, headerSize)),
controlErr: make(chan error, 1),
sendHdr: header(make([]byte, headerSize)),
sendErr: make(chan error, 1),
recvWindow: initialStreamWindow,
sendWindow: initialStreamWindow,
recvNotifyCh: make(chan struct{}, 1),
sendNotifyCh: make(chan struct{}, 1),
}
return s
}
// Session returns the associated stream session
func (s *Stream) Session() *Session {
return s.session
}
// StreamID returns the ID of this stream
func (s *Stream) StreamID() uint32 {
return s.id
}
// Read is used to read from the stream
func (s *Stream) Read(b []byte) (n int, err error) {
defer asyncNotify(s.recvNotifyCh)
START:
s.stateLock.Lock()
switch s.state {
case streamLocalClose:
fallthrough
case streamRemoteClose:
fallthrough
case streamClosed:
if s.recvBuf.Len() == 0 {
s.stateLock.Unlock()
return 0, io.EOF
}
case streamReset:
s.stateLock.Unlock()
return 0, ErrConnectionReset
}
s.stateLock.Unlock()
// If there is no data available, block
s.recvLock.Lock()
if s.recvBuf.Len() == 0 {
s.recvLock.Unlock()
goto WAIT
}
// Read any bytes
n, _ = s.recvBuf.Read(b)
s.recvLock.Unlock()
// Send a window update potentially
err = s.sendWindowUpdate()
return n, err
WAIT:
var timeout <-chan time.Time
if !s.readDeadline.IsZero() {
delay := s.readDeadline.Sub(time.Now())
timeout = time.After(delay)
}
select {
case <-s.recvNotifyCh:
goto START
case <-timeout:
return 0, ErrTimeout
}
}
// Write is used to write to the stream
func (s *Stream) Write(b []byte) (n int, err error) {
s.sendLock.Lock()
defer s.sendLock.Unlock()
total := 0
for total < len(b) {
n, err := s.write(b[total:])
total += n
if err != nil {
return total, err
}
}
return total, nil
}
// write is used to write to the stream, may return on
// a short write.
func (s *Stream) write(b []byte) (n int, err error) {
var flags uint16
var max uint32
var body io.Reader
START:
s.stateLock.Lock()
switch s.state {
case streamLocalClose:
fallthrough
case streamClosed:
s.stateLock.Unlock()
return 0, ErrStreamClosed
case streamReset:
s.stateLock.Unlock()
return 0, ErrConnectionReset
}
s.stateLock.Unlock()
// If there is no data available, block
window := atomic.LoadUint32(&s.sendWindow)
if window == 0 {
goto WAIT
}
// Determine the flags if any
flags = s.sendFlags()
// Send up to our send window
max = min(window, uint32(len(b)))
body = bytes.NewReader(b[:max])
// Send the header
s.sendHdr.encode(typeData, flags, s.id, max)
if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
return 0, err
}
// Reduce our send window
atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
// Unlock
return int(max), err
WAIT:
var timeout <-chan time.Time
if !s.writeDeadline.IsZero() {
delay := s.writeDeadline.Sub(time.Now())
timeout = time.After(delay)
}
select {
case <-s.sendNotifyCh:
goto START
case <-timeout:
return 0, ErrTimeout
}
return 0, nil
}
// sendFlags determines any flags that are appropriate
// based on the current stream state
func (s *Stream) sendFlags() uint16 {
s.stateLock.Lock()
defer s.stateLock.Unlock()
var flags uint16
switch s.state {
case streamInit:
flags |= flagSYN
s.state = streamSYNSent
case streamSYNReceived:
flags |= flagACK
s.state = streamEstablished
}
return flags
}
// sendWindowUpdate potentially sends a window update enabling
// further writes to take place. Must be invoked with the lock.
func (s *Stream) sendWindowUpdate() error {
s.controlHdrLock.Lock()
defer s.controlHdrLock.Unlock()
// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)
// Determine the flags if any
flags := s.sendFlags()
// Check if we can omit the update
if delta < (max/2) && flags == 0 {
return nil
}
// Update our window
atomic.AddUint32(&s.recvWindow, delta)
// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
return err
}
return nil
}
// sendClose is used to send a FIN
func (s *Stream) sendClose() error {
s.controlHdrLock.Lock()
defer s.controlHdrLock.Unlock()
flags := s.sendFlags()
flags |= flagFIN
s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
return err
}
return nil
}
// Close is used to close the stream
func (s *Stream) Close() error {
closeStream := false
s.stateLock.Lock()
switch s.state {
// Opened means we need to signal a close
case streamSYNSent:
fallthrough
case streamSYNReceived:
fallthrough
case streamEstablished:
s.state = streamLocalClose
goto SEND_CLOSE
case streamLocalClose:
case streamRemoteClose:
s.state = streamClosed
closeStream = true
goto SEND_CLOSE
case streamClosed:
case streamReset:
default:
panic("unhandled state")
}
s.stateLock.Unlock()
return nil
SEND_CLOSE:
s.stateLock.Unlock()
s.sendClose()
s.notifyWaiting()
if closeStream {
s.session.closeStream(s.id)
}
return nil
}
// forceClose is used for when the session is exiting
func (s *Stream) forceClose() {
s.stateLock.Lock()
s.state = streamClosed
s.stateLock.Unlock()
s.notifyWaiting()
}
// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) error {
// Close the stream without holding the state lock
closeStream := false
defer func() {
if closeStream {
s.session.closeStream(s.id)
}
}()
s.stateLock.Lock()
defer s.stateLock.Unlock()
if flags&flagACK == flagACK {
if s.state == streamSYNSent {
s.state = streamEstablished
}
s.session.establishStream()
}
if flags&flagFIN == flagFIN {
switch s.state {
case streamSYNSent:
fallthrough
case streamSYNReceived:
fallthrough
case streamEstablished:
s.state = streamRemoteClose
s.notifyWaiting()
case streamLocalClose:
s.state = streamClosed
closeStream = true
s.notifyWaiting()
default:
s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
return ErrUnexpectedFlag
}
}
if flags&flagRST == flagRST {
if s.state == streamSYNSent {
s.session.establishStream()
}
s.state = streamReset
closeStream = true
s.notifyWaiting()
}
return nil
}
// notifyWaiting notifies all the waiting channels
func (s *Stream) notifyWaiting() {
asyncNotify(s.recvNotifyCh)
asyncNotify(s.sendNotifyCh)
}
// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
if err := s.processFlags(flags); err != nil {
return err
}
// Increase window, unblock a sender
atomic.AddUint32(&s.sendWindow, hdr.Length())
asyncNotify(s.sendNotifyCh)
return nil
}
// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if err := s.processFlags(flags); err != nil {
return err
}
// Check that our recv window is not exceeded
length := hdr.Length()
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}
// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}
// Copy into buffer
s.recvLock.Lock()
if _, err := io.Copy(&s.recvBuf, conn); err != nil {
s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
s.recvLock.Unlock()
return err
}
// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvLock.Unlock()
// Unblock any readers
asyncNotify(s.recvNotifyCh)
return nil
}
// SetDeadline sets the read and write deadlines
func (s *Stream) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
if err := s.SetWriteDeadline(t); err != nil {
return err
}
return nil
}
// SetReadDeadline sets the deadline for future Read calls.
func (s *Stream) SetReadDeadline(t time.Time) error {
s.readDeadline = t
return nil
}
// SetWriteDeadline sets the deadline for future Write calls
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.writeDeadline = t
return nil
}
package yamux
// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
return
}
select {
case ch <- err:
default:
}
}
// asyncNotify is used to signal a waiting goroutine
func asyncNotify(ch chan struct{}) {
select {
case ch <- struct{}{}:
default:
}
}
// min computes the minimum of two values
func min(a, b uint32) uint32 {
if a < b {
return a
}
return b
}
package yamux
import (
"testing"
)
func TestAsyncSendErr(t *testing.T) {
ch := make(chan error)
asyncSendErr(ch, ErrTimeout)
select {
case <-ch:
t.Fatalf("should not get")
default:
}
ch = make(chan error, 1)
asyncSendErr(ch, ErrTimeout)
select {
case <-ch:
default:
t.Fatalf("should get")
}
}
func TestAsyncNotify(t *testing.T) {
ch := make(chan struct{})
asyncNotify(ch)
select {
case <-ch:
t.Fatalf("should not get")
default:
}
ch = make(chan struct{}, 1)
asyncNotify(ch)
select {
case <-ch:
default:
t.Fatalf("should get")
}
}
func TestMin(t *testing.T) {
if min(1, 2) != 1 {
t.Fatalf("bad")
}
if min(2, 1) != 1 {
t.Fatalf("bad")
}
}
# go-keyspace
This is a package extracted from go-ipfs.
Its purpose it to be used to compare a set of keys based on a given
metric. The primary metric used is XOR, as in kademlia.
package keyspace
import (
"sort"
"math/big"
)
// Key represents an identifier in a KeySpace. It holds a reference to the
// associated KeySpace, as well references to both the Original identifier,
// as well as the new, KeySpace Bytes one.
type Key struct {
// Space is the KeySpace this Key is related to.
Space KeySpace
// Original is the original value of the identifier
Original []byte
// Bytes is the new value of the identifier, in the KeySpace.
Bytes []byte
}
// Equal returns whether this key is equal to another.
func (k1 Key) Equal(k2 Key) bool {
if k1.Space != k2.Space {
panic("k1 and k2 not in same key space.")
}
return k1.Space.Equal(k1, k2)
}
// Less returns whether this key comes before another.
func (k1 Key) Less(k2 Key) bool {
if k1.Space != k2.Space {
panic("k1 and k2 not in same key space.")
}
return k1.Space.Less(k1, k2)
}
// Distance returns this key's distance to another
func (k1 Key) Distance(k2 Key) *big.Int {
if k1.Space != k2.Space {
panic("k1 and k2 not in same key space.")
}
return k1.Space.Distance(k1, k2)
}
// KeySpace is an object used to do math on identifiers. Each keyspace has its
// own properties and rules. See XorKeySpace.
type KeySpace interface {
// Key converts an identifier into a Key in this space.
Key([]byte) Key
// Equal returns whether keys are equal in this key space
Equal(Key, Key) bool
// Distance returns the distance metric in this key space
Distance(Key, Key) *big.Int
// Less returns whether the first key is smaller than the second.
Less(Key, Key) bool
}
// byDistanceToCenter is a type used to sort Keys by proximity to a center.
type byDistanceToCenter struct {
Center Key
Keys []Key
}
func (s byDistanceToCenter) Len() int {
return len(s.Keys)
}
func (s byDistanceToCenter) Swap(i, j int) {
s.Keys[i], s.Keys[j] = s.Keys[j], s.Keys[i]
}
func (s byDistanceToCenter) Less(i, j int) bool {
a := s.Center.Distance(s.Keys[i])
b := s.Center.Distance(s.Keys[j])
return a.Cmp(b) == -1
}
// SortByDistance takes a KeySpace, a center Key, and a list of Keys toSort.
// It returns a new list, where the Keys toSort have been sorted by their
// distance to the center Key.
func SortByDistance(sp KeySpace, center Key, toSort []Key) []Key {
toSortCopy := make([]Key, len(toSort))
copy(toSortCopy, toSort)
bdtc := &byDistanceToCenter{
Center: center,
Keys: toSortCopy, // copy
}
sort.Sort(bdtc)
return bdtc.Keys
}
{
"name": "go-keyspace",
"author": "joe",
"version": "1.0.0",
"language": "go",
"gx":{
"dvcsimport":"github.com/whyrusleeping/go-keyspace"
}
}
package keyspace
import (
"bytes"
"crypto/sha256"
"math/big"
)
// XORKeySpace is a KeySpace which:
// - normalizes identifiers using a cryptographic hash (sha256)
// - measures distance by XORing keys together
var XORKeySpace = &xorKeySpace{}
var _ KeySpace = XORKeySpace // ensure it conforms
type xorKeySpace struct{}
// Key converts an identifier into a Key in this space.
func (s *xorKeySpace) Key(id []byte) Key {
hash := sha256.Sum256(id)
key := hash[:]
return Key{
Space: s,
Original: id,
Bytes: key,
}
}
// Equal returns whether keys are equal in this key space
func (s *xorKeySpace) Equal(k1, k2 Key) bool {
return bytes.Equal(k1.Bytes, k2.Bytes)
}
// Distance returns the distance metric in this key space
func (s *xorKeySpace) Distance(k1, k2 Key) *big.Int {
// XOR the keys
k3 := XOR(k1.Bytes, k2.Bytes)
// interpret it as an integer
dist := big.NewInt(0).SetBytes(k3)
return dist
}
// Less returns whether the first key is smaller than the second.
func (s *xorKeySpace) Less(k1, k2 Key) bool {
a := k1.Bytes
b := k2.Bytes
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return a[i] < b[i]
}
}
return true
}
// ZeroPrefixLen returns the number of consecutive zeroes in a byte slice.
func ZeroPrefixLen(id []byte) int {
for i := 0; i < len(id); i++ {
for j := 0; j < 8; j++ {
if (id[i]>>uint8(7-j))&0x1 != 0 {
return i*8 + j
}
}
}
return len(id) * 8
}
// XOR takes two byte slices, XORs them together, returns the resulting slice.
func XOR(a, b []byte) []byte {
c := make([]byte, len(a))
for i := 0; i < len(a); i++ {
c[i] = a[i] ^ b[i]
}
return c
}
package keyspace
import (
"bytes"
"math/big"
"testing"
)
func TestPrefixLen(t *testing.T) {
cases := [][]byte{
{0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00},
{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
{0x00, 0x58, 0xFF, 0x80, 0x00, 0x00, 0xF0},
}
lens := []int{24, 56, 9}
for i, c := range cases {
r := ZeroPrefixLen(c)
if r != lens[i] {
t.Errorf("ZeroPrefixLen failed: %v != %v", r, lens[i])
}
}
}
func TestXorKeySpace(t *testing.T) {
ids := [][]byte{
{0xFF, 0xFF, 0xFF, 0xFF},
{0x00, 0x00, 0x00, 0x00},
{0xFF, 0xFF, 0xFF, 0xF0},
}
ks := [][2]Key{
{XORKeySpace.Key(ids[0]), XORKeySpace.Key(ids[0])},
{XORKeySpace.Key(ids[1]), XORKeySpace.Key(ids[1])},
{XORKeySpace.Key(ids[2]), XORKeySpace.Key(ids[2])},
}
for i, set := range ks {
if !set[0].Equal(set[1]) {
t.Errorf("Key not eq. %v != %v", set[0], set[1])
}
if !bytes.Equal(set[0].Bytes, set[1].Bytes) {
t.Errorf("Key gen failed. %v != %v", set[0].Bytes, set[1].Bytes)
}
if !bytes.Equal(set[0].Original, ids[i]) {
t.Errorf("ptrs to original. %v != %v", set[0].Original, ids[i])
}
if len(set[0].Bytes) != 32 {
t.Errorf("key length incorrect. 32 != %d", len(set[0].Bytes))
}
}
for i := 1; i < len(ks); i++ {
if ks[i][0].Less(ks[i-1][0]) == ks[i-1][0].Less(ks[i][0]) {
t.Errorf("less should be different.")
}
if ks[i][0].Distance(ks[i-1][0]).Cmp(ks[i-1][0].Distance(ks[i][0])) != 0 {
t.Errorf("distance should be the same.")
}
if ks[i][0].Equal(ks[i-1][0]) {
t.Errorf("Keys should not be eq. %v != %v", ks[i][0], ks[i-1][0])
}
}
}
func TestDistancesAndCenterSorting(t *testing.T) {
adjs := [][]byte{
{173, 149, 19, 27, 192, 183, 153, 192, 177, 175, 71, 127, 177, 79, 207, 38, 166, 169, 247, 96, 121, 228, 139, 240, 144, 172, 183, 232, 54, 123, 253, 14},
{223, 63, 97, 152, 4, 169, 47, 219, 64, 87, 25, 45, 196, 61, 215, 72, 234, 119, 138, 220, 82, 188, 73, 140, 232, 5, 36, 192, 20, 184, 17, 25},
{73, 176, 221, 176, 149, 143, 22, 42, 129, 124, 213, 114, 232, 95, 189, 154, 18, 3, 122, 132, 32, 199, 53, 185, 58, 157, 117, 78, 52, 146, 157, 127},
{73, 176, 221, 176, 149, 143, 22, 42, 129, 124, 213, 114, 232, 95, 189, 154, 18, 3, 122, 132, 32, 199, 53, 185, 58, 157, 117, 78, 52, 146, 157, 127},
{73, 176, 221, 176, 149, 143, 22, 42, 129, 124, 213, 114, 232, 95, 189, 154, 18, 3, 122, 132, 32, 199, 53, 185, 58, 157, 117, 78, 52, 146, 157, 126},
{73, 0, 221, 176, 149, 143, 22, 42, 129, 124, 213, 114, 232, 95, 189, 154, 18, 3, 122, 132, 32, 199, 53, 185, 58, 157, 117, 78, 52, 146, 157, 127},
}
keys := make([]Key, len(adjs))
for i, a := range adjs {
keys[i] = Key{Space: XORKeySpace, Bytes: a}
}
cmp := func(a int64, b *big.Int) int {
return big.NewInt(a).Cmp(b)
}
if 0 != cmp(0, keys[2].Distance(keys[3])) {
t.Errorf("distance calculation wrong: %v", keys[2].Distance(keys[3]))
}
if 0 != cmp(1, keys[2].Distance(keys[4])) {
t.Errorf("distance calculation wrong: %v", keys[2].Distance(keys[4]))
}
d1 := keys[2].Distance(keys[5])
d2 := XOR(keys[2].Bytes, keys[5].Bytes)
d2 = d2[len(keys[2].Bytes)-len(d1.Bytes()):] // skip empty space for big
if !bytes.Equal(d1.Bytes(), d2) {
t.Errorf("bytes should be the same. %v == %v", d1.Bytes(), d2)
}
if -1 != cmp(2<<32, keys[2].Distance(keys[5])) {
t.Errorf("2<<32 should be smaller")
}
keys2 := SortByDistance(XORKeySpace, keys[2], keys)
order := []int{2, 3, 4, 5, 1, 0}
for i, o := range order {
if !bytes.Equal(keys[o].Bytes, keys2[i].Bytes) {
t.Errorf("order is wrong. %d?? %v == %v", o, keys[o], keys2[i])
}
}
}
The MIT License (MIT)
Copyright (c) 2015 Jeromy Johnson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# go-multiaddr-filter -- CIDR netmasks with multiaddr
This module creates very simple [multiaddr](https://github.com/jbenet/go-multiaddr) formatted cidr netmasks.
It doesn't do full multiaddr parsing to save on vendoring things and perf. The `net` package will take care of verifying the validity of the network part anyway.
## Usage
```go
import filter "github.com/whyrusleeping/multiaddr-filter"
filter.NewMask("/ip4/192.168.0.0/24") // ipv4
filter.NewMask("/ip6/fe80::/64") // ipv6
```
package mask
import (
"errors"
"fmt"
"net"
"strings"
manet "gx/QmNT7d1e4Xcp3KcsvxyzUHVtqrR43uypoxLLzdKj6YZga2/go-multiaddr-net"
)
var ErrInvalidFormat = errors.New("invalid multiaddr-filter format")
func NewMask(a string) (*net.IPNet, error) {
parts := strings.Split(a, "/")
if parts[0] != "" {
return nil, ErrInvalidFormat
}
if len(parts) != 5 {
return nil, ErrInvalidFormat
}
// check it's a valid filter address. ip + cidr
isip := parts[1] == "ip4" || parts[1] == "ip6"
iscidr := parts[3] == "ipcidr"
if !isip || !iscidr {
return nil, ErrInvalidFormat
}
_, ipn, err := net.ParseCIDR(parts[2] + "/" + parts[4])
if err != nil {
return nil, err
}
return ipn, nil
}
func ConvertIPNet(n *net.IPNet) (string, error) {
addr, err := manet.FromIP(n.IP)
if err != nil {
return "", err
}
b, _ := n.Mask.Size()
return fmt.Sprintf("%s/ipcidr/%d", addr, b), nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment