diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index 0931953cc..a1ec979eb 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -253,9 +253,10 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int } func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { + log.Printf("[DEBUG] %s: Connecting to stream %d", name, id) conn, err := mux.Accept(id) if err != nil { - log.Printf("'%s' accept error: %s", name, err) + log.Printf("[ERR] '%s' accept error: %s", name, err) return } @@ -271,8 +272,8 @@ func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io } written, err := io.Copy(dst, src) - log.Printf("%d bytes written for '%s'", written, name) + log.Printf("[INFO] %d bytes written for '%s'", written, name) if err != nil { - log.Printf("'%s' copy error: %s", name, err) + log.Printf("[ERR] '%s' copy error: %s", name, err) } } diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index d2ec5f6b2..174b3e36d 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -33,6 +33,7 @@ type muxPacketType byte const ( muxPacketSyn muxPacketType = iota + muxPacketSynAck muxPacketAck muxPacketFin muxPacketData @@ -77,49 +78,27 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { // If the stream isn't closed, then it is already open somehow stream.mu.Lock() + defer stream.mu.Unlock() if stream.state != streamStateSynRecv && stream.state != streamStateClosed { - stream.mu.Unlock() return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) } - if stream.state == streamStateSynRecv { - // Fast track establishing since we already got the syn - stream.setState(streamStateEstablished) - stream.mu.Unlock() - } - - if stream.state != streamStateEstablished { - // Go into the listening state + if stream.state == streamStateClosed { + // Go into the listening state and wait for a syn stream.setState(streamStateListen) - - // Register a state change listener to wait for changes - stateCh := make(chan streamState, 10) - stream.registerStateListener(stateCh) - defer func() { - stream.mu.Lock() - defer stream.mu.Unlock() - stream.deregisterStateListener(stateCh) - }() - - stream.mu.Unlock() - - // Wait for the connection to establish - ACCEPT_ESTABLISH_LOOP: - for { - state := <-stateCh - switch state { - case streamStateListen: - case streamStateEstablished: - break ACCEPT_ESTABLISH_LOOP - default: - defer stream.mu.Unlock() - return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state) - } + if err := stream.waitState(streamStateSynRecv); err != nil { + return nil, err } } - // Send the ack down - if _, err := m.write(stream.id, muxPacketAck, nil); err != nil { + if stream.state == streamStateSynRecv { + // Send a syn-ack + if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil { + return nil, err + } + } + + if err := stream.waitState(streamStateEstablished); err != nil { return nil, err } @@ -136,8 +115,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { // If the stream isn't closed, then it is already open somehow stream.mu.Lock() + defer stream.mu.Unlock() if stream.state != streamStateClosed { - stream.mu.Unlock() return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) } @@ -147,28 +126,12 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { } stream.setState(streamStateSynSent) - // Register a state change listener to wait for changes - stateCh := make(chan streamState, 10) - stream.registerStateListener(stateCh) - defer func() { - stream.mu.Lock() - defer stream.mu.Unlock() - stream.deregisterStateListener(stateCh) - }() - - stream.mu.Unlock() - - for { - state := <-stateCh - switch state { - case streamStateSynSent: - case streamStateEstablished: - return stream, nil - default: - defer stream.mu.Unlock() - return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state) - } + if err := stream.waitState(streamStateEstablished); err != nil { + return nil, err } + + m.write(id, muxPacketAck, nil) + return stream, nil } // NextId returns the next available stream ID that isn't currently @@ -247,6 +210,7 @@ func (m *MuxConn) loop() { // Force close every stream that we know about when we exit so // that they all read EOF and don't block forever. defer func() { + log.Printf("[INFO] Mux connection loop exiting") m.mu.Lock() defer m.mu.Unlock() for _, w := range m.streams { @@ -288,12 +252,23 @@ func (m *MuxConn) loop() { return } - //log.Printf("[DEBUG] Stream %d received packet %d", id, packetType) + log.Printf("[TRACE] Stream %d received packet %d", id, packetType) switch packetType { + case muxPacketSyn: + stream.mu.Lock() + switch stream.state { + case streamStateClosed: + fallthrough + case streamStateListen: + stream.setState(streamStateSynRecv) + default: + log.Printf("[ERR] Syn received for stream in state: %d", stream.state) + } + stream.mu.Unlock() case muxPacketAck: stream.mu.Lock() switch stream.state { - case streamStateSynSent: + case streamStateSynRecv: stream.setState(streamStateEstablished) case streamStateFinWait1: stream.setState(streamStateFinWait2) @@ -301,15 +276,13 @@ func (m *MuxConn) loop() { log.Printf("[ERR] Ack received for stream in state: %d", stream.state) } stream.mu.Unlock() - case muxPacketSyn: + case muxPacketSynAck: stream.mu.Lock() switch stream.state { - case streamStateClosed: - stream.setState(streamStateSynRecv) - case streamStateListen: + case streamStateSynSent: stream.setState(streamStateEstablished) default: - log.Printf("[ERR] Syn received for stream in state: %d", stream.state) + log.Printf("[ERR] SynAck received for stream in state: %d", stream.state) } stream.mu.Unlock() case muxPacketFin: @@ -451,6 +424,7 @@ func (s *Stream) deregisterStateListener(ch chan<- streamState) { } func (s *Stream) setState(state streamState) { + log.Printf("[TRACE] Stream %d went to state %d", s.id, state) s.state = state s.stateUpdated = time.Now().UTC() for ch, _ := range s.stateChange { @@ -460,3 +434,22 @@ func (s *Stream) setState(state streamState) { } } } + +func (s *Stream) waitState(target streamState) error { + // Register a state change listener to wait for changes + stateCh := make(chan streamState, 10) + s.registerStateListener(stateCh) + s.mu.Unlock() + + defer func() { + s.mu.Lock() + s.deregisterStateListener(stateCh) + }() + + state := <-stateCh + if state == target { + return nil + } else { + return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) + } +}