2013-12-08 21:20:27 -05:00
|
|
|
package rpc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/binary"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"sync"
|
2013-12-09 17:24:55 -05:00
|
|
|
"time"
|
2013-12-08 21:20:27 -05:00
|
|
|
)
|
|
|
|
|
|
|
|
// MuxConn is a connection that can be used bi-directionally for RPC. Normally,
// Go RPC only allows client-to-server connections. This allows the client
// to actually act as a server as well.
//
// MuxConn works using a fairly dumb multiplexing technique of simply
// framing every piece of data sent into a prefix + data format. Streams
// are established using a subset of the TCP protocol. Only a subset is
// necessary since we assume ordering on the underlying RWC.
//
// Each frame on the wire is a big-endian uint32 stream ID, a one-byte
// muxPacketType, a big-endian int32 payload length, and then that many
// payload bytes.
type MuxConn struct {
	// curId is the next candidate stream ID handed out by NextId
	// (read and advanced under mu).
	curId uint32

	// rwc is the single underlying connection shared by all streams.
	rwc io.ReadWriteCloser

	// streams maps stream IDs to their Stream; guarded by mu.
	streams map[uint32]*Stream

	// mu guards streams and curId.
	mu sync.RWMutex

	// wlock serializes frame writes so that frames from different
	// streams are never interleaved on rwc.
	wlock sync.Mutex
}
|
|
|
|
|
2013-12-09 17:24:55 -05:00
|
|
|
// muxPacketType identifies the kind of frame sent over a MuxConn. It is
// written to the wire as a single byte, so the values below are part of
// the protocol and must not be renumbered.
type muxPacketType byte

// Frame kinds: Syn, Ack and Fin drive the stream handshake, Data carries
// stream payload.
const (
	muxPacketSyn  muxPacketType = 0
	muxPacketAck  muxPacketType = 1
	muxPacketFin  muxPacketType = 2
	muxPacketData muxPacketType = 3
)
|
|
|
|
|
2013-12-08 21:20:27 -05:00
|
|
|
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
|
|
|
m := &MuxConn{
|
|
|
|
rwc: rwc,
|
2013-12-09 19:27:13 -05:00
|
|
|
streams: make(map[uint32]*Stream),
|
2013-12-08 21:20:27 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
go m.loop()
|
|
|
|
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
|
2013-12-08 21:30:29 -05:00
|
|
|
// Close closes the underlying io.ReadWriteCloser. This will also close
|
|
|
|
// all streams that are open.
|
|
|
|
func (m *MuxConn) Close() error {
|
2013-12-10 14:40:17 -05:00
|
|
|
m.mu.RLock()
|
|
|
|
defer m.mu.RUnlock()
|
2013-12-08 21:30:29 -05:00
|
|
|
|
|
|
|
// Close all the streams
|
|
|
|
for _, w := range m.streams {
|
|
|
|
w.Close()
|
|
|
|
}
|
2013-12-09 19:27:13 -05:00
|
|
|
m.streams = make(map[uint32]*Stream)
|
2013-12-08 21:30:29 -05:00
|
|
|
|
|
|
|
return m.rwc.Close()
|
|
|
|
}
|
|
|
|
|
2013-12-09 17:24:55 -05:00
|
|
|
// Accept accepts a multiplexed connection with the given ID. This
|
|
|
|
// will block until a request is made to connect.
|
2013-12-09 19:27:13 -05:00
|
|
|
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
2013-12-09 17:24:55 -05:00
|
|
|
stream, err := m.openStream(id)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// If the stream isn't closed, then it is already open somehow
|
|
|
|
stream.mu.Lock()
|
|
|
|
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
|
|
|
|
stream.mu.Unlock()
|
2013-12-10 13:34:35 -05:00
|
|
|
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
2013-12-09 17:24:55 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
if stream.state == streamStateSynRecv {
|
|
|
|
// Fast track establishing since we already got the syn
|
|
|
|
stream.setState(streamStateEstablished)
|
|
|
|
stream.mu.Unlock()
|
|
|
|
}
|
|
|
|
|
|
|
|
if stream.state != streamStateEstablished {
|
|
|
|
// Go into the listening state
|
|
|
|
stream.setState(streamStateListen)
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
|
|
|
// Wait for the connection to establish
|
|
|
|
ACCEPT_ESTABLISH_LOOP:
|
|
|
|
for {
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
stream.mu.Lock()
|
|
|
|
switch stream.state {
|
|
|
|
case streamStateListen:
|
|
|
|
stream.mu.Unlock()
|
2013-12-10 14:40:17 -05:00
|
|
|
case streamStateClosed:
|
|
|
|
// This can happen if it becomes established, some data is sent,
|
|
|
|
// and it closed all within the time period we wait above.
|
|
|
|
// This case will be fixed when we have edge-triggered checks.
|
|
|
|
fallthrough
|
2013-12-09 17:24:55 -05:00
|
|
|
case streamStateEstablished:
|
|
|
|
stream.mu.Unlock()
|
|
|
|
break ACCEPT_ESTABLISH_LOOP
|
|
|
|
default:
|
|
|
|
defer stream.mu.Unlock()
|
2013-12-10 14:40:17 -05:00
|
|
|
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
2013-12-09 17:24:55 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Send the ack down
|
|
|
|
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return stream, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Dial opens a connection to the remote end using the given stream ID.
|
|
|
|
// An Accept on the remote end will only work with if the IDs match.
|
2013-12-09 19:27:13 -05:00
|
|
|
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
2013-12-09 17:24:55 -05:00
|
|
|
stream, err := m.openStream(id)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// If the stream isn't closed, then it is already open somehow
|
|
|
|
stream.mu.Lock()
|
|
|
|
if stream.state != streamStateClosed {
|
|
|
|
stream.mu.Unlock()
|
2013-12-10 13:34:35 -05:00
|
|
|
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
2013-12-09 17:24:55 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// Open a connection
|
|
|
|
if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
stream.setState(streamStateSynSent)
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
|
|
|
for {
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
stream.mu.Lock()
|
|
|
|
switch stream.state {
|
|
|
|
case streamStateSynSent:
|
|
|
|
stream.mu.Unlock()
|
2013-12-10 14:40:17 -05:00
|
|
|
case streamStateClosed:
|
|
|
|
// This can happen if it becomes established, some data is sent,
|
|
|
|
// and it closed all within the time period we wait above.
|
|
|
|
// This case will be fixed when we have edge-triggered checks.
|
|
|
|
fallthrough
|
2013-12-09 17:24:55 -05:00
|
|
|
case streamStateEstablished:
|
|
|
|
stream.mu.Unlock()
|
|
|
|
return stream, nil
|
|
|
|
default:
|
|
|
|
defer stream.mu.Unlock()
|
2013-12-10 14:40:17 -05:00
|
|
|
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
2013-12-09 17:24:55 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2013-12-09 19:27:13 -05:00
|
|
|
// NextId returns the next available stream ID that isn't currently
|
|
|
|
// taken.
|
|
|
|
func (m *MuxConn) NextId() uint32 {
|
|
|
|
m.mu.Lock()
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
|
|
|
for {
|
2013-12-10 13:34:35 -05:00
|
|
|
result := m.curId
|
2013-12-09 19:27:13 -05:00
|
|
|
m.curId++
|
2013-12-10 13:34:35 -05:00
|
|
|
if _, ok := m.streams[result]; !ok {
|
|
|
|
return result
|
|
|
|
}
|
2013-12-09 19:27:13 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
2013-12-10 14:40:17 -05:00
|
|
|
// First grab a read-lock if we have the stream already we can
|
|
|
|
// cheaply return it.
|
|
|
|
m.mu.RLock()
|
|
|
|
if stream, ok := m.streams[id]; ok {
|
|
|
|
m.mu.RUnlock()
|
|
|
|
return stream, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Now acquire a full blown write lock so we can create the stream
|
|
|
|
m.mu.RUnlock()
|
2013-12-08 21:20:27 -05:00
|
|
|
m.mu.Lock()
|
2013-12-09 17:24:55 -05:00
|
|
|
defer m.mu.Unlock()
|
2013-12-08 21:20:27 -05:00
|
|
|
|
2013-12-10 14:40:17 -05:00
|
|
|
// We have to check this again because there is a time period
|
|
|
|
// above where we couldn't lost this lock.
|
2013-12-09 17:24:55 -05:00
|
|
|
if stream, ok := m.streams[id]; ok {
|
|
|
|
return stream, nil
|
2013-12-08 21:20:27 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// Create the stream object and channel where data will be sent to
|
|
|
|
dataR, dataW := io.Pipe()
|
2013-12-10 13:44:57 -05:00
|
|
|
writeCh := make(chan []byte, 10)
|
2013-12-08 21:39:14 -05:00
|
|
|
|
|
|
|
// Set the data channel so we can write to it.
|
2013-12-08 21:20:27 -05:00
|
|
|
stream := &Stream{
|
2013-12-10 13:44:57 -05:00
|
|
|
id: id,
|
|
|
|
mux: m,
|
|
|
|
reader: dataR,
|
|
|
|
writeCh: writeCh,
|
2013-12-08 21:20:27 -05:00
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
stream.setState(streamStateClosed)
|
2013-12-08 21:20:27 -05:00
|
|
|
|
2013-12-10 13:44:57 -05:00
|
|
|
// Start the goroutine that will read from the queue and write
|
|
|
|
// data out.
|
|
|
|
go func() {
|
2013-12-10 14:40:17 -05:00
|
|
|
defer dataW.Close()
|
|
|
|
|
2013-12-10 13:44:57 -05:00
|
|
|
for {
|
|
|
|
data := <-writeCh
|
2013-12-10 14:40:17 -05:00
|
|
|
if data == nil {
|
|
|
|
// A nil is a tombstone letting us know we're done
|
|
|
|
// accepting data.
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2013-12-10 13:44:57 -05:00
|
|
|
if _, err := dataW.Write(data); err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
2013-12-09 17:24:55 -05:00
|
|
|
m.streams[id] = stream
|
|
|
|
return m.streams[id], nil
|
2013-12-08 21:20:27 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
func (m *MuxConn) loop() {
|
2013-12-08 21:30:29 -05:00
|
|
|
defer m.Close()
|
|
|
|
|
2013-12-09 19:27:13 -05:00
|
|
|
var id uint32
|
2013-12-09 17:24:55 -05:00
|
|
|
var packetType muxPacketType
|
|
|
|
var length int32
|
2013-12-08 21:20:27 -05:00
|
|
|
for {
|
|
|
|
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
|
|
|
|
log.Printf("[ERR] Error reading stream ID: %s", err)
|
|
|
|
return
|
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
|
|
|
|
log.Printf("[ERR] Error reading packet type: %s", err)
|
|
|
|
return
|
|
|
|
}
|
2013-12-08 21:20:27 -05:00
|
|
|
if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
|
|
|
|
log.Printf("[ERR] Error reading length: %s", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO(mitchellh): probably would be better to re-use a buffer...
|
|
|
|
data := make([]byte, length)
|
2013-12-09 17:24:55 -05:00
|
|
|
if length > 0 {
|
|
|
|
if _, err := m.rwc.Read(data); err != nil {
|
|
|
|
log.Printf("[ERR] Error reading data: %s", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
stream, err := m.openStream(id)
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("[ERR] Error opening stream %d: %s", id, err)
|
2013-12-08 21:20:27 -05:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2013-12-10 14:43:02 -05:00
|
|
|
//log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
|
2013-12-09 17:24:55 -05:00
|
|
|
switch packetType {
|
|
|
|
case muxPacketAck:
|
|
|
|
stream.mu.Lock()
|
2013-12-10 14:40:17 -05:00
|
|
|
switch stream.state {
|
|
|
|
case streamStateSynSent:
|
2013-12-09 17:24:55 -05:00
|
|
|
stream.setState(streamStateEstablished)
|
2013-12-10 14:40:17 -05:00
|
|
|
case streamStateFinWait1:
|
|
|
|
stream.remoteClose()
|
|
|
|
default:
|
2013-12-09 17:24:55 -05:00
|
|
|
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
|
|
|
}
|
|
|
|
stream.mu.Unlock()
|
|
|
|
case muxPacketSyn:
|
|
|
|
stream.mu.Lock()
|
|
|
|
switch stream.state {
|
|
|
|
case streamStateClosed:
|
|
|
|
stream.setState(streamStateSynRecv)
|
|
|
|
case streamStateListen:
|
|
|
|
stream.setState(streamStateEstablished)
|
|
|
|
default:
|
|
|
|
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
|
|
|
|
}
|
|
|
|
stream.mu.Unlock()
|
|
|
|
case muxPacketFin:
|
|
|
|
stream.mu.Lock()
|
2013-12-10 14:40:17 -05:00
|
|
|
switch stream.state {
|
|
|
|
case streamStateEstablished:
|
|
|
|
m.write(id, muxPacketAck, nil)
|
|
|
|
fallthrough
|
|
|
|
case streamStateFinWait1:
|
|
|
|
stream.remoteClose()
|
|
|
|
|
|
|
|
// Remove this stream from being active so that it
|
|
|
|
// can be re-used
|
|
|
|
m.mu.Lock()
|
|
|
|
delete(m.streams, stream.id)
|
|
|
|
m.mu.Unlock()
|
|
|
|
default:
|
|
|
|
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
stream.mu.Unlock()
|
|
|
|
|
|
|
|
case muxPacketData:
|
|
|
|
stream.mu.Lock()
|
|
|
|
if stream.state == streamStateEstablished {
|
2013-12-10 13:44:57 -05:00
|
|
|
select {
|
|
|
|
case stream.writeCh <- data:
|
|
|
|
default:
|
|
|
|
log.Printf("[ERR] Failed to write data, buffer full: %d", id)
|
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
} else {
|
|
|
|
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
|
|
|
|
}
|
|
|
|
stream.mu.Unlock()
|
2013-12-08 21:20:27 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2013-12-09 19:27:13 -05:00
|
|
|
func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) {
|
2013-12-08 21:20:27 -05:00
|
|
|
m.wlock.Lock()
|
|
|
|
defer m.wlock.Unlock()
|
|
|
|
|
|
|
|
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
2013-12-08 21:20:27 -05:00
|
|
|
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
if len(p) == 0 {
|
|
|
|
return 0, nil
|
|
|
|
}
|
2013-12-08 21:20:27 -05:00
|
|
|
return m.rwc.Write(p)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Stream is a single stream of data and implements io.ReadWriteCloser
type Stream struct {
	// id is the stream ID carried in every frame for this stream.
	id uint32

	// mux is the connection this stream is multiplexed over.
	mux *MuxConn

	// reader is the source for Read; openStream sets it to the read end
	// of an io.Pipe.
	reader io.Reader

	// state is the handshake state; read and written under mu.
	state streamState

	// stateUpdated is the UTC time of the last setState call.
	stateUpdated time.Time

	mu sync.Mutex

	// writeCh queues payloads received for this stream. A nil slice is a
	// tombstone that tells the consuming goroutine to stop.
	writeCh chan<- []byte
}
|
|
|
|
|
2013-12-09 17:24:55 -05:00
|
|
|
// streamState is the handshake state of a Stream, modeled on a subset of
// TCP's connection states. The numeric values appear in error and log
// messages, so they are spelled out explicitly.
type streamState byte

const (
	streamStateClosed      streamState = 0
	streamStateListen      streamState = 1
	streamStateSynRecv     streamState = 2
	streamStateSynSent     streamState = 3
	streamStateEstablished streamState = 4
	streamStateFinWait1    streamState = 5
)
|
|
|
|
|
2013-12-08 21:20:27 -05:00
|
|
|
func (s *Stream) Close() error {
|
2013-12-09 17:24:55 -05:00
|
|
|
s.mu.Lock()
|
|
|
|
if s.state != streamStateEstablished {
|
2013-12-10 14:40:17 -05:00
|
|
|
s.mu.Unlock()
|
2013-12-09 17:24:55 -05:00
|
|
|
return fmt.Errorf("Stream in bad state: %d", s.state)
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2013-12-10 14:40:17 -05:00
|
|
|
s.setState(streamStateFinWait1)
|
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
|
|
for {
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
s.mu.Lock()
|
|
|
|
switch s.state {
|
|
|
|
case streamStateFinWait1:
|
|
|
|
s.mu.Unlock()
|
|
|
|
case streamStateClosed:
|
|
|
|
s.mu.Unlock()
|
|
|
|
return nil
|
|
|
|
default:
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state)
|
|
|
|
}
|
|
|
|
}
|
2013-12-09 17:24:55 -05:00
|
|
|
|
2013-12-08 21:20:27 -05:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Stream) Read(p []byte) (int, error) {
|
|
|
|
return s.reader.Read(p)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Stream) Write(p []byte) (int, error) {
|
2013-12-10 14:40:17 -05:00
|
|
|
s.mu.Lock()
|
|
|
|
state := s.state
|
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
|
|
if state != streamStateEstablished {
|
2013-12-10 17:11:50 -05:00
|
|
|
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
|
2013-12-10 14:40:17 -05:00
|
|
|
}
|
|
|
|
|
2013-12-09 17:24:55 -05:00
|
|
|
return s.mux.write(s.id, muxPacketData, p)
|
|
|
|
}
|
|
|
|
|
2013-12-10 14:40:17 -05:00
|
|
|
func (s *Stream) remoteClose() {
|
|
|
|
s.setState(streamStateClosed)
|
|
|
|
s.writeCh <- nil
|
|
|
|
}
|
|
|
|
|
2013-12-09 17:24:55 -05:00
|
|
|
func (s *Stream) setState(state streamState) {
|
|
|
|
s.state = state
|
|
|
|
s.stateUpdated = time.Now().UTC()
|
2013-12-08 21:20:27 -05:00
|
|
|
}
|