diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index 13e864b5f..f5f0ce804 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -21,7 +21,7 @@ type MuxConn struct { curId uint32 rwc io.ReadWriteCloser streams map[uint32]*Stream - mu sync.Mutex + mu sync.RWMutex wlock sync.Mutex } @@ -48,8 +48,8 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { // Close closes the underlying io.ReadWriteCloser. This will also close // all streams that are open. func (m *MuxConn) Close() error { - m.mu.Lock() - defer m.mu.Unlock() + m.mu.RLock() + defer m.mu.RUnlock() // Close all the streams for _, w := range m.streams { @@ -94,12 +94,17 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { switch stream.state { case streamStateListen: stream.mu.Unlock() + case streamStateClosed: + // This can happen if it becomes established, some data is sent, + // and it closed all within the time period we wait above. + // This case will be fixed when we have edge-triggered checks. + fallthrough case streamStateEstablished: stream.mu.Unlock() break ACCEPT_ESTABLISH_LOOP default: defer stream.mu.Unlock() - return nil, fmt.Errorf("Stream went to bad state: %d", stream.state) + return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state) } } } @@ -140,12 +145,17 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { switch stream.state { case streamStateSynSent: stream.mu.Unlock() + case streamStateClosed: + // This can happen if it becomes established, some data is sent, + // and it closed all within the time period we wait above. + // This case will be fixed when we have edge-triggered checks. + fallthrough case streamStateEstablished: stream.mu.Unlock() return stream, nil default: defer stream.mu.Unlock() - return nil, fmt.Errorf("Stream went to bad state: %d", stream.state) + return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state) } } } @@ -166,9 +176,21 @@ func (m *MuxConn) NextId() uint32 { } func (m *MuxConn) openStream(id uint32) (*Stream, error) { + // First grab a read-lock if we have the stream already we can + // cheaply return it. + m.mu.RLock() + if stream, ok := m.streams[id]; ok { + m.mu.RUnlock() + return stream, nil + } + + // Now acquire a full blown write lock so we can create the stream + m.mu.RUnlock() m.mu.Lock() defer m.mu.Unlock() + // We have to check this again because there is a time period + // above where we couldn't lost this lock. if stream, ok := m.streams[id]; ok { return stream, nil } @@ -182,7 +204,6 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) { id: id, mux: m, reader: dataR, - writer: dataW, writeCh: writeCh, } stream.setState(streamStateClosed) @@ -190,8 +211,16 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) { // Start the goroutine that will read from the queue and write // data out. go func() { + defer dataW.Close() + for { data := <-writeCh + if data == nil { + // A nil is a tombstone letting us know we're done + // accepting data. + return + } + if _, err := dataW.Write(data); err != nil { return } @@ -237,12 +266,16 @@ func (m *MuxConn) loop() { return } + log.Printf("[DEBUG] Stream %d received packet %d", id, packetType) switch packetType { case muxPacketAck: stream.mu.Lock() - if stream.state == streamStateSynSent { + switch stream.state { + case streamStateSynSent: stream.setState(streamStateEstablished) - } else { + case streamStateFinWait1: + stream.remoteClose() + default: log.Printf("[ERR] Ack received for stream in state: %d", stream.state) } stream.mu.Unlock() @@ -259,13 +292,23 @@ func (m *MuxConn) loop() { stream.mu.Unlock() case muxPacketFin: stream.mu.Lock() - stream.setState(streamStateClosed) - stream.writer.Close() + switch stream.state { + case streamStateEstablished: + m.write(id, muxPacketAck, nil) + fallthrough + case streamStateFinWait1: + stream.remoteClose() + + // Remove this stream from being active so that it + // can be re-used + m.mu.Lock() + delete(m.streams, stream.id) + m.mu.Unlock() + default: + log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) + } stream.mu.Unlock() - m.mu.Lock() - delete(m.streams, stream.id) - m.mu.Unlock() case muxPacketData: stream.mu.Lock() if stream.state == streamStateEstablished { @@ -306,7 +349,6 @@ type Stream struct { id uint32 mux *MuxConn reader io.Reader - writer io.WriteCloser state streamState stateUpdated time.Time mu sync.Mutex @@ -321,23 +363,37 @@ const ( streamStateSynRecv streamStateSynSent streamStateEstablished - streamStateFinWait + streamStateFinWait1 ) func (s *Stream) Close() error { s.mu.Lock() - defer s.mu.Unlock() - if s.state != streamStateEstablished { + s.mu.Unlock() return fmt.Errorf("Stream in bad state: %d", s.state) } if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil { return err } + s.setState(streamStateFinWait1) + s.mu.Unlock() + + for { + time.Sleep(50 * time.Millisecond) + s.mu.Lock() + switch s.state { + case streamStateFinWait1: + s.mu.Unlock() + case streamStateClosed: + s.mu.Unlock() + return nil + default: + defer s.mu.Unlock() + return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state) + } + } - s.setState(streamStateClosed) - s.writer.Close() return nil } @@ -346,9 +402,22 @@ func (s *Stream) Read(p []byte) (int, error) { } func (s *Stream) Write(p []byte) (int, error) { + s.mu.Lock() + state := s.state + s.mu.Unlock() + + if state != streamStateEstablished { + return 0, fmt.Errorf("Stream in bad state to send: %d", state) + } + return s.mux.write(s.id, muxPacketData, p) } +func (s *Stream) remoteClose() { + s.setState(streamStateClosed) + s.writeCh <- nil +} + func (s *Stream) setState(state streamState) { s.state = state s.stateUpdated = time.Now().UTC()