From ec68a3fd39f15e9b8f1a92b620c12ad8f71414f6 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Mon, 9 Dec 2013 16:27:13 -0800 Subject: [PATCH] packer/rpc: MuxConn can return next available stream ID --- packer/rpc/muxconn.go | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index ee34b515f..7247d4ed1 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -18,8 +18,9 @@ import ( // are established using a subset of the TCP protocol. Only a subset is // necessary since we assume ordering on the underlying RWC. type MuxConn struct { + curId uint32 rwc io.ReadWriteCloser - streams map[byte]*Stream + streams map[uint32]*Stream mu sync.RWMutex wlock sync.Mutex } @@ -36,7 +37,7 @@ const ( func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { m := &MuxConn{ rwc: rwc, - streams: make(map[byte]*Stream), + streams: make(map[uint32]*Stream), } go m.loop() @@ -54,14 +55,14 @@ func (m *MuxConn) Close() error { for _, w := range m.streams { w.Close() } - m.streams = make(map[byte]*Stream) + m.streams = make(map[uint32]*Stream) return m.rwc.Close() } // Accept accepts a multiplexed connection with the given ID. This // will block until a request is made to connect. -func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) { +func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { stream, err := m.openStream(id) if err != nil { return nil, err @@ -113,7 +114,7 @@ func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) { // Dial opens a connection to the remote end using the given stream ID. // An Accept on the remote end will only work with if the IDs match. -func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) { +func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { stream, err := m.openStream(id) if err != nil { return nil, err @@ -149,7 +150,22 @@ func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) { } } -func (m *MuxConn) openStream(id byte) (*Stream, error) { +// NextId returns the next available stream ID that isn't currently +// taken. +func (m *MuxConn) NextId() uint32 { + m.mu.Lock() + defer m.mu.Unlock() + + for { + if _, ok := m.streams[m.curId]; !ok { + return m.curId + } + + m.curId++ + } +} + +func (m *MuxConn) openStream(id uint32) (*Stream, error) { m.mu.Lock() defer m.mu.Unlock() @@ -176,7 +192,7 @@ func (m *MuxConn) openStream(id byte) (*Stream, error) { func (m *MuxConn) loop() { defer m.Close() - var id byte + var id uint32 var packetType muxPacketType var length int32 for { @@ -249,7 +265,7 @@ func (m *MuxConn) loop() { } } -func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error) { +func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) { m.wlock.Lock() defer m.wlock.Unlock() @@ -270,7 +286,7 @@ func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error) // Stream is a single stream of data and implements io.ReadWriteCloser type Stream struct { - id byte + id uint32 mux *MuxConn reader io.Reader writer io.WriteCloser