// Source: packer-cn/packer/rpc/muxconn.go (484 lines, 11 KiB, Go),
// captured from a repository blame view (earliest shown change:
// 2013-12-08 21:20:27 -05:00).

package rpc
import (
"encoding/binary"
"fmt"
"io"
"log"
"sync"
"time"
)
2013-12-10 20:31:54 -05:00
// MuxConn is able to multiplex multiple streams on top of any
// io.ReadWriteCloser. These streams act like TCP connections (Dial, Accept,
// Close, full duplex, etc.).
//
// The underlying io.ReadWriteCloser is expected to guarantee delivery
// and ordering, such as TCP. Congestion control and such aren't implemented
// by the streams, so that is also up to the underlying connection.
2013-12-08 21:20:27 -05:00
//
// MuxConn works using a fairly dumb multiplexing technique of simply
2013-12-09 17:29:28 -05:00
// framing every piece of data sent into a prefix + data format. Streams
// are established using a subset of the TCP protocol. Only a subset is
// necessary since we assume ordering on the underlying RWC.
2013-12-08 21:20:27 -05:00
type MuxConn struct {
curId uint32
2013-12-08 21:20:27 -05:00
rwc io.ReadWriteCloser
streams map[uint32]*Stream
2013-12-10 14:40:17 -05:00
mu sync.RWMutex
2013-12-08 21:20:27 -05:00
wlock sync.Mutex
doneCh chan struct{}
2013-12-08 21:20:27 -05:00
}
// muxPacketType identifies the kind of frame sent on the wire. Both peers
// must agree on these byte values, so they are written out explicitly.
type muxPacketType byte

const (
	muxPacketSyn    muxPacketType = 0 // request to open a stream
	muxPacketSynAck muxPacketType = 1 // reply accepting a Syn
	muxPacketAck    muxPacketType = 2 // acknowledges a SynAck or a Fin
	muxPacketFin    muxPacketType = 3 // sender is closing its side
	muxPacketData   muxPacketType = 4 // carries stream payload bytes
)
2013-12-10 20:31:54 -05:00
// Create a new MuxConn around any io.ReadWriteCloser.
2013-12-08 21:20:27 -05:00
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{
rwc: rwc,
streams: make(map[uint32]*Stream),
doneCh: make(chan struct{}),
2013-12-08 21:20:27 -05:00
}
go m.cleaner()
2013-12-08 21:20:27 -05:00
go m.loop()
return m
}
// Close closes the underlying io.ReadWriteCloser. This will also close
// all streams that are open.
func (m *MuxConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
// Close all the streams
for _, w := range m.streams {
w.Close()
}
m.streams = make(map[uint32]*Stream)
2013-12-10 20:31:54 -05:00
// Close the actual connection. This will also force the loop
// to end since it'll read EOF or closed connection.
return m.rwc.Close()
}
// Accept accepts a multiplexed connection with the given ID. This
// will block until a request is made to connect.
//
// If the peer's Syn has not arrived yet, the stream is put into Listen
// and Accept waits for it. Once in SynRecv, Accept replies with SynAck and
// then waits for the peer's Ack to reach Established.
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
	s, err := m.openStream(id)
	if err != nil {
		return nil, err
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	switch s.state {
	case streamStateClosed:
		// No Syn yet: listen and wait for one.
		s.setState(streamStateListen)
		if err := s.waitState(streamStateSynRecv); err != nil {
			return nil, err
		}
	case streamStateSynRecv:
		// The Syn arrived before Accept was called.
	default:
		return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, s.state)
	}

	// Re-check: waitState released the lock, so the state may have moved.
	if s.state == streamStateSynRecv {
		if _, err := m.write(s.id, muxPacketSynAck, nil); err != nil {
			return nil, err
		}
	}

	if err := s.waitState(streamStateEstablished); err != nil {
		return nil, err
	}
	return s, nil
}
// Dial opens a connection to the remote end using the given stream ID.
// An Accept on the remote end will only work if the IDs match.
//
// Dial performs a three-way handshake: it sends Syn, waits for the peer's
// SynAck (which moves the stream to Established), then replies with Ack.
// Any transport error during the handshake is returned.
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
	stream, err := m.openStream(id)
	if err != nil {
		return nil, err
	}

	stream.mu.Lock()
	defer stream.mu.Unlock()

	// Only a stream that is currently closed can be dialed.
	if stream.state != streamStateClosed {
		return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
	}

	if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil {
		return nil, err
	}
	// Setting the state after sending is safe: we hold stream.mu, so the
	// read loop cannot process the SynAck until waitState releases it.
	stream.setState(streamStateSynSent)

	if err := stream.waitState(streamStateEstablished); err != nil {
		return nil, err
	}

	// The peer's Accept only reaches Established once it receives this
	// Ack, so a failed write means the stream is not usable.
	if _, err := m.write(id, muxPacketAck, nil); err != nil {
		return nil, err
	}

	return stream, nil
}
// NextId returns the next available stream ID that isn't currently
// taken, and advances the internal counter past it.
func (m *MuxConn) NextId() uint32 {
	m.mu.Lock()
	defer m.mu.Unlock()

	for ; ; m.curId++ {
		candidate := m.curId
		if _, taken := m.streams[candidate]; !taken {
			m.curId++
			return candidate
		}
	}
}
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
2013-12-10 14:40:17 -05:00
// First grab a read-lock if we have the stream already we can
// cheaply return it.
m.mu.RLock()
if stream, ok := m.streams[id]; ok {
m.mu.RUnlock()
return stream, nil
}
// Now acquire a full blown write lock so we can create the stream
m.mu.RUnlock()
2013-12-08 21:20:27 -05:00
m.mu.Lock()
defer m.mu.Unlock()
2013-12-08 21:20:27 -05:00
// Make sure we attempt to use the next biggest stream ID
if id >= m.curId {
m.curId = id + 1
}
2013-12-10 14:40:17 -05:00
// We have to check this again because there is a time period
// above where we couldn't lost this lock.
if stream, ok := m.streams[id]; ok {
return stream, nil
2013-12-08 21:20:27 -05:00
}
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 256)
// Set the data channel so we can write to it.
2013-12-08 21:20:27 -05:00
stream := &Stream{
id: id,
mux: m,
reader: dataR,
writeCh: writeCh,
stateChange: make(map[chan<- streamState]struct{}),
2013-12-08 21:20:27 -05:00
}
stream.setState(streamStateClosed)
2013-12-08 21:20:27 -05:00
// Start the goroutine that will read from the queue and write
// data out.
go func() {
2013-12-10 14:40:17 -05:00
defer dataW.Close()
for {
data := <-writeCh
2013-12-10 14:40:17 -05:00
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil {
return
}
}
}()
m.streams[id] = stream
return m.streams[id], nil
2013-12-08 21:20:27 -05:00
}
// cleaner periodically removes streams in the Closed state from m.streams,
// so NextId and openStream can reuse their IDs.
//
// Once doneCh is closed by the read loop, cleaner makes one final pass and
// returns. That pass removes every Closed stream and sends the writer
// tombstone to each remaining stream, so their readers see EOF.
func (m *MuxConn) cleaner() {
	// interval is both the polling period and the minimum time a stream
	// must have been Closed before it is reaped during normal operation.
	const interval = 500 * time.Millisecond

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		done := false
		select {
		case <-ticker.C:
		case <-m.doneCh:
			done = true
		}

		m.mu.Lock()
		for id, s := range m.streams {
			s.mu.Lock()
			switch {
			case s.state == streamStateClosed:
				// openStream registers new streams in the Closed state, and
				// Accept/Dial only move them on after locking s.mu again.
				// Reaping a freshly created stream in that window would orphan
				// the caller's stream. Requiring a full interval in the Closed
				// state narrows that race; it is skipped when shutting down.
				if done || time.Since(s.stateUpdated) >= interval {
					delete(m.streams, id)
				}
			case done:
				// The read loop is gone; unblock this stream's readers.
				s.closeWriter()
			}
			s.mu.Unlock()
		}
		m.mu.Unlock()

		if done {
			return
		}
	}
}
2013-12-08 21:20:27 -05:00
func (m *MuxConn) loop() {
2013-12-10 20:31:54 -05:00
// Force close every stream that we know about when we exit so
// that they all read EOF and don't block forever.
defer func() {
log.Printf("[INFO] Mux connection loop exiting")
close(m.doneCh)
}()
var id uint32
var packetType muxPacketType
var length int32
2013-12-08 21:20:27 -05:00
for {
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
log.Printf("[ERR] Error reading stream ID: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
log.Printf("[ERR] Error reading packet type: %s", err)
return
}
2013-12-08 21:20:27 -05:00
if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
log.Printf("[ERR] Error reading length: %s", err)
return
}
// TODO(mitchellh): probably would be better to re-use a buffer...
data := make([]byte, length)
if length > 0 {
if _, err := m.rwc.Read(data); err != nil {
log.Printf("[ERR] Error reading data: %s", err)
return
}
}
stream, err := m.openStream(id)
if err != nil {
log.Printf("[ERR] Error opening stream %d: %s", id, err)
2013-12-08 21:20:27 -05:00
return
}
//log.Printf("[TRACE] Stream %d received packet %d", id, packetType)
switch packetType {
case muxPacketSyn:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
fallthrough
case streamStateListen:
stream.setState(streamStateSynRecv)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketAck:
stream.mu.Lock()
2013-12-10 14:40:17 -05:00
switch stream.state {
case streamStateSynRecv:
stream.setState(streamStateEstablished)
2013-12-10 14:40:17 -05:00
case streamStateFinWait1:
stream.setState(streamStateFinWait2)
case streamStateLastAck:
stream.closeWriter()
fallthrough
case streamStateClosing:
stream.setState(streamStateClosed)
2013-12-10 14:40:17 -05:00
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSynAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
2013-12-10 14:40:17 -05:00
switch stream.state {
case streamStateEstablished:
stream.closeWriter()
stream.setState(streamStateCloseWait)
2013-12-10 14:40:17 -05:00
m.write(id, muxPacketAck, nil)
case streamStateFinWait2:
stream.closeWriter()
stream.setState(streamStateClosed)
m.write(id, muxPacketAck, nil)
case streamStateFinWait1:
stream.closeWriter()
stream.setState(streamStateClosing)
m.write(id, muxPacketAck, nil)
2013-12-10 14:40:17 -05:00
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
stream.mu.Unlock()
case muxPacketData:
stream.mu.Lock()
switch stream.state {
case streamStateFinWait1:
fallthrough
case streamStateFinWait2:
fallthrough
case streamStateEstablished:
select {
case stream.writeCh <- data:
default:
panic(fmt.Sprintf("Failed to write data, buffer full for stream %d", id))
}
default:
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
2013-12-08 21:20:27 -05:00
}
}
}
func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) {
2013-12-08 21:20:27 -05:00
m.wlock.Lock()
defer m.wlock.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
return 0, err
}
2013-12-08 21:20:27 -05:00
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
2013-12-08 21:20:27 -05:00
return m.rwc.Write(p)
}
2013-12-10 20:31:54 -05:00
// Stream is a single stream of data and implements io.ReadWriteCloser.
// A Stream is full-duplex so you can write data as well as read data.
2013-12-08 21:20:27 -05:00
type Stream struct {
id uint32
mux *MuxConn
reader io.Reader
state streamState
stateChange map[chan<- streamState]struct{}
stateUpdated time.Time
mu sync.Mutex
writeCh chan<- []byte
2013-12-08 21:20:27 -05:00
}
// streamState is a stream's position in its TCP-like lifecycle.
type streamState byte

const (
	streamStateClosed      streamState = 0 // initial/final state; no connection
	streamStateListen      streamState = 1 // Accept is waiting for a Syn
	streamStateSynRecv     streamState = 2 // Syn received; SynAck due, awaiting Ack
	streamStateSynSent     streamState = 3 // Dial sent Syn, awaiting SynAck
	streamStateEstablished streamState = 4 // open in both directions
	streamStateFinWait1    streamState = 5 // we sent Fin, awaiting its Ack
	streamStateFinWait2    streamState = 6 // our Fin was acked, awaiting peer's Fin
	streamStateCloseWait   streamState = 7 // peer sent Fin; we may still write
	streamStateClosing     streamState = 8 // both sent Fin; awaiting Ack of ours
	streamStateLastAck     streamState = 9 // we sent Fin after peer's; awaiting Ack
)
2013-12-08 21:20:27 -05:00
func (s *Stream) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.state != streamStateEstablished && s.state != streamStateCloseWait {
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if s.state == streamStateEstablished {
s.setState(streamStateFinWait1)
} else {
s.setState(streamStateLastAck)
2013-12-10 14:40:17 -05:00
}
s.mux.write(s.id, muxPacketFin, nil)
2013-12-08 21:20:27 -05:00
return nil
}
// Read reads payload bytes the peer has sent on this stream, delegating
// to the stream's internal reader.
func (s *Stream) Read(p []byte) (int, error) {
	n, err := s.reader.Read(p)
	return n, err
}
func (s *Stream) Write(p []byte) (int, error) {
2013-12-10 14:40:17 -05:00
s.mu.Lock()
state := s.state
s.mu.Unlock()
if state != streamStateEstablished && state != streamStateCloseWait {
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
2013-12-10 14:40:17 -05:00
}
return s.mux.write(s.id, muxPacketData, p)
}
func (s *Stream) closeWriter() {
2013-12-10 14:40:17 -05:00
s.writeCh <- nil
}
func (s *Stream) setState(state streamState) {
//log.Printf("[TRACE] Stream %d went to state %d", s.id, state)
s.state = state
s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange {
select {
case ch <- state:
default:
}
}
2013-12-08 21:20:27 -05:00
}
// waitState blocks until the stream's next state transition and reports
// whether that transition was to target.
//
// It must be called with s.mu held. The lock is released while waiting, so
// the read loop can change the state, and re-acquired before returning.
//
// If the connection's read loop exits first, no further transitions can
// happen, so an error is returned instead of blocking forever.
func (s *Stream) waitState(target streamState) error {
	// Register the listener before unlocking so no transition is missed.
	stateCh := make(chan streamState, 10)
	s.stateChange[stateCh] = struct{}{}
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.stateChange, stateCh)
	}()

	var state streamState
	select {
	case state = <-stateCh:
	case <-s.mux.doneCh:
		// Prefer a transition that raced with shutdown over an error.
		select {
		case state = <-stateCh:
		default:
			return fmt.Errorf("Stream %d connection closed while waiting for state %d", s.id, target)
		}
	}

	if state != target {
		return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
	}
	return nil
}