diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index 46f2f43ea..da61c3673 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -27,6 +27,7 @@ type MuxConn struct { streams map[uint32]*Stream mu sync.RWMutex wlock sync.Mutex + doneCh chan struct{} } type muxPacketType byte @@ -44,8 +45,10 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { m := &MuxConn{ rwc: rwc, streams: make(map[uint32]*Stream), + doneCh: make(chan struct{}), } + go m.cleaner() go m.loop() return m @@ -211,18 +214,45 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) { return m.streams[id], nil } +func (m *MuxConn) cleaner() { + for { + done := false + select { + case <-time.After(500 * time.Millisecond): + case <-m.doneCh: + done = true + } + + m.mu.Lock() + for id, s := range m.streams { + s.mu.Lock() + if s.state == streamStateClosed { + delete(m.streams, id) + } + s.mu.Unlock() + } + + if done { + for _, s := range m.streams { + s.mu.Lock() + s.closeWriter() + s.mu.Unlock() + } + } + m.mu.Unlock() + + if done { + return + } + } +} + func (m *MuxConn) loop() { // Force close every stream that we know about when we exit so // that they all read EOF and don't block forever. defer func() { log.Printf("[INFO] Mux connection loop exiting") - m.mu.Lock() - defer m.mu.Unlock() - for _, w := range m.streams { - w.mu.Lock() - w.remoteClose() - w.mu.Unlock() - } + close(m.doneCh) }() var id uint32 @@ -277,6 +307,11 @@ func (m *MuxConn) loop() { stream.setState(streamStateEstablished) case streamStateFinWait1: stream.setState(streamStateFinWait2) + case streamStateLastAck: + stream.closeWriter() + fallthrough + case streamStateClosing: + stream.setState(streamStateClosed) default: log.Printf("[ERR] Ack received for stream in state: %d", stream.state) } @@ -294,20 +329,17 @@ func (m *MuxConn) loop() { stream.mu.Lock() switch stream.state { case streamStateEstablished: + stream.closeWriter() stream.setState(streamStateCloseWait) m.write(id, muxPacketAck, nil) - - // Close the writer on our end since we won't receive any - // more data. - stream.writeCh <- nil - case streamStateFinWait1: - fallthrough case streamStateFinWait2: - stream.remoteClose() - - m.mu.Lock() - delete(m.streams, stream.id) - m.mu.Unlock() + stream.closeWriter() + stream.setState(streamStateClosed) + m.write(id, muxPacketAck, nil) + case streamStateFinWait1: + stream.closeWriter() + stream.setState(streamStateClosing) + m.write(id, muxPacketAck, nil) default: log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) } @@ -377,6 +409,8 @@ const ( streamStateFinWait1 streamStateFinWait2 streamStateCloseWait + streamStateClosing + streamStateLastAck ) func (s *Stream) Close() error { @@ -390,7 +424,7 @@ func (s *Stream) Close() error { if s.state == streamStateEstablished { s.setState(streamStateFinWait1) } else { - s.remoteClose() + s.setState(streamStateLastAck) } s.mux.write(s.id, muxPacketFin, nil) @@ -413,8 +447,7 @@ func (s *Stream) Write(p []byte) (int, error) { return s.mux.write(s.id, muxPacketData, p) } -func (s *Stream) remoteClose() { - s.setState(streamStateClosed) +func (s *Stream) closeWriter() { s.writeCh <- nil }