From edbdee5dee43a6666247403601ec133e05ea822c Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Fri, 20 Dec 2013 09:49:44 -0800 Subject: [PATCH] packer/rpc: accept/dial stream IDs are unique [GH-727] --- packer/rpc/client.go | 2 +- packer/rpc/communicator.go | 7 +- packer/rpc/muxconn.go | 318 +++++++++++++++++++++++-------------- packer/rpc/muxconn_test.go | 8 +- packer/rpc/server.go | 2 +- 5 files changed, 207 insertions(+), 130 deletions(-) diff --git a/packer/rpc/client.go b/packer/rpc/client.go index 95b5a6d8d..0fce79bf8 100644 --- a/packer/rpc/client.go +++ b/packer/rpc/client.go @@ -17,7 +17,7 @@ type Client struct { } func NewClient(rwc io.ReadWriteCloser) (*Client, error) { - result, err := newClientWithMux(NewMuxConn(rwc, 0), 0) + result, err := newClientWithMux(NewMuxConn(rwc), 0) if err != nil { return nil, err } diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index a1ec979eb..a3d267091 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -164,7 +164,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface } }() - if args.StdinStreamId > 0 { + if args.StdinStreamId >= 0 { conn, err := c.mux.Dial(args.StdinStreamId) if err != nil { close(doneCh) @@ -175,7 +175,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface cmd.Stdin = conn } - if args.StdoutStreamId > 0 { + if args.StdoutStreamId >= 0 { conn, err := c.mux.Dial(args.StdoutStreamId) if err != nil { close(doneCh) @@ -186,7 +186,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface cmd.Stdout = conn } - if args.StderrStreamId > 0 { + if args.StderrStreamId >= 0 { conn, err := c.mux.Dial(args.StderrStreamId) if err != nil { close(doneCh) @@ -253,7 +253,6 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int } func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { - log.Printf("[DEBUG] %s: Connecting to stream %d", name, id) conn, err := mux.Accept(id) if err != nil { log.Printf("[ERR] '%s' accept error: %s", name, err) diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index 6718abb7c..0b51d0d6f 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -22,16 +22,25 @@ import ( // are established using a subset of the TCP protocol. Only a subset is // necessary since we assume ordering on the underlying RWC. type MuxConn struct { - curId uint32 - rwc io.ReadWriteCloser - streams map[uint32]*Stream - mu sync.RWMutex - wlock sync.Mutex - doneCh chan struct{} + curId uint32 + rwc io.ReadWriteCloser + streamsAccept map[uint32]*Stream + streamsDial map[uint32]*Stream + mu sync.RWMutex + muAccept sync.RWMutex + muDial sync.RWMutex + wlock sync.Mutex + doneCh chan struct{} } +type muxPacketFrom byte type muxPacketType byte +const ( + muxPacketFromAccept muxPacketFrom = iota + muxPacketFromDial +) + const ( muxPacketSyn muxPacketType = iota muxPacketSynAck @@ -40,13 +49,24 @@ const ( muxPacketData ) +func (f muxPacketFrom) String() string { + switch f { + case muxPacketFromAccept: + return "accept" + case muxPacketFromDial: + return "dial" + default: + panic("unknown from type") + } +} + // Create a new MuxConn around any io.ReadWriteCloser. -func NewMuxConn(rwc io.ReadWriteCloser, startId uint32) *MuxConn { +func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { m := &MuxConn{ - rwc: rwc, - streams: make(map[uint32]*Stream), - doneCh: make(chan struct{}), - curId: startId, + rwc: rwc, + streamsAccept: make(map[uint32]*Stream), + streamsDial: make(map[uint32]*Stream), + doneCh: make(chan struct{}), } go m.cleaner() @@ -62,10 +82,14 @@ func (m *MuxConn) Close() error { defer m.mu.Unlock() // Close all the streams - for _, w := range m.streams { + for _, w := range m.streamsAccept { w.Close() } - m.streams = make(map[uint32]*Stream) + for _, w := range m.streamsDial { + w.Close() + } + m.streamsAccept = make(map[uint32]*Stream) + m.streamsDial = make(map[uint32]*Stream) // Close the actual connection. This will also force the loop // to end since it'll read EOF or closed connection. @@ -75,14 +99,22 @@ func (m *MuxConn) Close() error { // Accept accepts a multiplexed connection with the given ID. This // will block until a request is made to connect. func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { - stream, err := m.openStream(id) - if err != nil { - return nil, err - } + //log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id) + + // Get the stream. It is okay if it is already in the list of streams + // because we may have prematurely received a syn for it. + m.muAccept.Lock() + stream, ok := m.streamsAccept[id] + if !ok { + stream = newStream(muxPacketFromAccept, id, m) + m.streamsAccept[id] = stream + } + m.muAccept.Unlock() - // If the stream isn't closed, then it is already open somehow stream.mu.Lock() defer stream.mu.Unlock() + + // If the stream isn't closed, then it is already open somehow if stream.state != streamStateSynRecv && stream.state != streamStateClosed { return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) } @@ -97,7 +129,7 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { if stream.state == streamStateSynRecv { // Send a syn-ack - if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil { + if _, err := stream.write(muxPacketSynAck, nil); err != nil { return nil, err } } @@ -112,110 +144,68 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { // Dial opens a connection to the remote end using the given stream ID. // An Accept on the remote end will only work with if the IDs match. func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { - stream, err := m.openStream(id) - if err != nil { - return nil, err + m.muDial.Lock() + + // If we have any streams with this ID, then it is a failure. The + // reaper should clear out old streams once in awhile. + if stream, ok := m.streamsDial[id]; ok { + m.muDial.Unlock() + return nil, fmt.Errorf( + "Stream %d already open for dial. State: %d", stream.state) } - // If the stream isn't closed, then it is already open somehow + // Create the new stream and put it in our list. We can then + // unlock because dialing will no longer be allowed on that ID. + stream := newStream(muxPacketFromDial, id, m) + m.streamsDial[id] = stream + + // Don't let anyone else mess with this stream stream.mu.Lock() defer stream.mu.Unlock() - if stream.state != streamStateClosed { - return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) - } + + m.muDial.Unlock() // Open a connection - if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil { + if _, err := stream.write(muxPacketSyn, nil); err != nil { return nil, err } + + // It is safe to set the state after the write above because + // we hold the stream lock. stream.setState(streamStateSynSent) if err := stream.waitState(streamStateEstablished); err != nil { return nil, err } - m.write(id, muxPacketAck, nil) + stream.write(muxPacketAck, nil) return stream, nil } -// NextId returns the next available stream ID that isn't currently +// NextId returns the next available listen stream ID that isn't currently // taken. func (m *MuxConn) NextId() uint32 { - m.mu.Lock() - defer m.mu.Unlock() + m.muAccept.Lock() + defer m.muAccept.Unlock() for { result := m.curId - m.curId += 2 - if _, ok := m.streams[result]; !ok { + m.curId += 1 + if _, ok := m.streamsAccept[result]; !ok { return result } } } -func (m *MuxConn) openStream(id uint32) (*Stream, error) { - // First grab a read-lock if we have the stream already we can - // cheaply return it. - m.mu.RLock() - if stream, ok := m.streams[id]; ok { - m.mu.RUnlock() - return stream, nil - } - - // Now acquire a full blown write lock so we can create the stream - m.mu.RUnlock() - m.mu.Lock() - defer m.mu.Unlock() - - // Make sure we attempt to use the next biggest stream ID - if id >= m.curId { - m.curId = id + 1 - } - - // We have to check this again because there is a time period - // above where we couldn't lost this lock. - if stream, ok := m.streams[id]; ok { - return stream, nil - } - - // Create the stream object and channel where data will be sent to - dataR, dataW := io.Pipe() - writeCh := make(chan []byte, 256) - - // Set the data channel so we can write to it. - stream := &Stream{ - id: id, - mux: m, - reader: dataR, - writeCh: writeCh, - stateChange: make(map[chan<- streamState]struct{}), - } - stream.setState(streamStateClosed) - - // Start the goroutine that will read from the queue and write - // data out. - go func() { - defer dataW.Close() - - for { - data := <-writeCh - if data == nil { - // A nil is a tombstone letting us know we're done - // accepting data. - return - } - - if _, err := dataW.Write(data); err != nil { - return - } - } - }() - - m.streams[id] = stream - return m.streams[id], nil -} - func (m *MuxConn) cleaner() { + checks := []struct { + Map map[uint32]*Stream + Lock *sync.RWMutex + }{ + {m.streamsAccept, &m.muAccept}, + {m.streamsDial, &m.muDial}, + } + for { done := false select { @@ -224,23 +214,28 @@ func (m *MuxConn) cleaner() { done = true } - m.mu.Lock() - for id, s := range m.streams { - s.mu.Lock() - if s.state == streamStateClosed { - delete(m.streams, id) - } - s.mu.Unlock() - } - - if done { - for _, s := range m.streams { + for _, check := range checks { + check.Lock.Lock() + for id, s := range check.Map { s.mu.Lock() - s.closeWriter() + + if done && s.state != streamStateClosed { + s.closeWriter() + } + + if s.state == streamStateClosed { + // Only clean up the streams that have been closed + // for a certain amount of time. + since := time.Now().UTC().Sub(s.stateUpdated) + if since > 2*time.Second { + delete(check.Map, id) + } + } + s.mu.Unlock() } + check.Lock.Unlock() } - m.mu.Unlock() if done { return @@ -256,10 +251,15 @@ func (m *MuxConn) loop() { close(m.doneCh) }() + var from muxPacketFrom var id uint32 var packetType muxPacketType var length int32 for { + if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil { + log.Printf("[ERR] Error reading stream direction: %s", err) + return + } if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { log.Printf("[ERR] Error reading stream ID: %s", err) return @@ -282,15 +282,47 @@ func (m *MuxConn) loop() { } } - stream, err := m.openStream(id) - if err != nil { - log.Printf("[ERR] Error opening stream %d: %s", id, err) - return + // Get the proper stream. Note that the map we look into is + // opposite the "from" because if the dial side is talking to + // us, we need to look into the accept map, and so on. + // + // Note: we also switch the "from" value so that logging + // below is correct. + var stream *Stream + switch from { + case muxPacketFromDial: + m.muAccept.Lock() + stream = m.streamsAccept[id] + m.muAccept.Unlock() + + from = muxPacketFromAccept + case muxPacketFromAccept: + m.muDial.Lock() + stream = m.streamsDial[id] + m.muDial.Unlock() + + from = muxPacketFromDial + default: + panic(fmt.Sprintf("Unknown stream direction: %d", from)) } - //log.Printf("[TRACE] Stream %d received packet %d", id, packetType) + //log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType) switch packetType { case muxPacketSyn: + // If the stream is nil, this is the only case where we'll + // automatically create the stream struct. + if stream == nil { + var ok bool + + m.muAccept.Lock() + stream, ok = m.streamsAccept[id] + if !ok { + stream = newStream(muxPacketFromAccept, id, m) + m.streamsAccept[id] = stream + } + m.muAccept.Unlock() + } + stream.mu.Lock() switch stream.state { case streamStateClosed: @@ -332,15 +364,15 @@ func (m *MuxConn) loop() { case streamStateEstablished: stream.closeWriter() stream.setState(streamStateCloseWait) - m.write(id, muxPacketAck, nil) + stream.write(muxPacketAck, nil) case streamStateFinWait2: stream.closeWriter() stream.setState(streamStateClosed) - m.write(id, muxPacketAck, nil) + stream.write(muxPacketAck, nil) case streamStateFinWait1: stream.closeWriter() stream.setState(streamStateClosing) - m.write(id, muxPacketAck, nil) + stream.write(muxPacketAck, nil) default: log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) } @@ -367,10 +399,13 @@ func (m *MuxConn) loop() { } } -func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) { +func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) { m.wlock.Lock() defer m.wlock.Unlock() + if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil { + return 0, err + } if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { return 0, err } @@ -389,6 +424,7 @@ func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error // Stream is a single stream of data and implements io.ReadWriteCloser. // A Stream is full-duplex so you can write data as well as read data. type Stream struct { + from muxPacketFrom id uint32 mux *MuxConn reader io.Reader @@ -414,6 +450,44 @@ const ( streamStateLastAck ) +func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { + // Create the stream object and channel where data will be sent to + dataR, dataW := io.Pipe() + writeCh := make(chan []byte, 256) + + // Set the data channel so we can write to it. + stream := &Stream{ + from: from, + id: id, + mux: m, + reader: dataR, + writeCh: writeCh, + stateChange: make(map[chan<- streamState]struct{}), + } + stream.setState(streamStateClosed) + + // Start the goroutine that will read from the queue and write + // data out. + go func() { + defer dataW.Close() + + for { + data := <-writeCh + if data == nil { + // A nil is a tombstone letting us know we're done + // accepting data. + return + } + + if _, err := dataW.Write(data); err != nil { + return + } + } + }() + + return stream +} + func (s *Stream) Close() error { s.mu.Lock() defer s.mu.Unlock() @@ -428,7 +502,7 @@ func (s *Stream) Close() error { s.setState(streamStateLastAck) } - s.mux.write(s.id, muxPacketFin, nil) + s.write(muxPacketFin, nil) return nil } @@ -445,7 +519,7 @@ func (s *Stream) Write(p []byte) (int, error) { return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) } - return s.mux.write(s.id, muxPacketData, p) + return s.write(muxPacketData, p) } func (s *Stream) closeWriter() { @@ -453,7 +527,7 @@ func (s *Stream) closeWriter() { } func (s *Stream) setState(state streamState) { - //log.Printf("[TRACE] Stream %d went to state %d", s.id, state) + //log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state) s.state = state s.stateUpdated = time.Now().UTC() for ch, _ := range s.stateChange { @@ -482,3 +556,7 @@ func (s *Stream) waitState(target streamState) error { return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) } } + +func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) { + return s.mux.write(s.from, s.id, dataType, p) +} diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go index 910bf66e1..a2f76bf3b 100644 --- a/packer/rpc/muxconn_test.go +++ b/packer/rpc/muxconn_test.go @@ -33,7 +33,7 @@ func testMux(t *testing.T) (client *MuxConn, server *MuxConn) { t.Fatalf("err: %s", err) } - server = NewMuxConn(conn, 1) + server = NewMuxConn(conn) }() // Client side @@ -41,7 +41,7 @@ func testMux(t *testing.T) (client *MuxConn, server *MuxConn) { if err != nil { t.Fatalf("err: %s", err) } - client = NewMuxConn(conn, 0) + client = NewMuxConn(conn) // Wait for the server <-doneCh @@ -241,14 +241,14 @@ func TestMuxConnNextId(t *testing.T) { a := client.NextId() b := client.NextId() - if a != 0 || b != 2 { + if a != 0 || b != 1 { t.Fatalf("IDs should increment") } a = server.NextId() b = server.NextId() - if a != 1 || b != 3 { + if a != 0 || b != 1 { t.Fatalf("IDs should increment: %d %d", a, b) } } diff --git a/packer/rpc/server.go b/packer/rpc/server.go index 3537c8e7f..5e4b5006b 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -36,7 +36,7 @@ type Server struct { // NewServer returns a new Packer RPC server. func NewServer(conn io.ReadWriteCloser) *Server { - result := newServerWithMux(NewMuxConn(conn, 1), 0) + result := newServerWithMux(NewMuxConn(conn), 0) result.closeMux = true return result }