From ae37050e8a89193ea9ba621729a4708172a029dd Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Mon, 30 Dec 2013 21:03:10 -0800 Subject: [PATCH] packer/rpc: muxconn can't use stream ID 0 ever --- CHANGELOG.md | 2 + packer/rpc/muxconn.go | 492 +++++++++++++++++++++--------------------- 2 files changed, 248 insertions(+), 246 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2c23b302..00091b9d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ ## 0.5.1 (unreleased) +BUG FIXES: +* core: If a stream ID loops around, don't let it use stream ID 0 [GH-767] ## 0.5.0 (12/30/2013) diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index db8faabf6..956e21aed 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -191,13 +191,13 @@ func (m *MuxConn) NextId() uint32 { m.muAccept.Lock() defer m.muAccept.Unlock() - // We never use stream ID 0 because 0 is the zero value of a uint32 - // and we want to reserve that for "not in use" - if m.curId == 0 { - m.curId = 1 - } - for { + // We never use stream ID 0 because 0 is the zero value of a uint32 + // and we want to reserve that for "not in use" + if m.curId == 0 { + m.curId = 1 + } + result := m.curId m.curId += 1 if _, ok := m.streamsAccept[result]; !ok { @@ -319,263 +319,263 @@ func (m *MuxConn) loop() { log.Printf( "[WARN] %p: Non-existent stream %d (%s) received packer %d", m, id, from, packetType) - continue + continue + } + + //log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType) + switch packetType { + case muxPacketSyn: + // If the stream is nil, this is the only case where we'll + // automatically create the stream struct. + if stream == nil { + var ok bool + + m.muAccept.Lock() + stream, ok = m.streamsAccept[id] + if !ok { + stream = newStream(muxPacketFromAccept, id, m) + m.streamsAccept[id] = stream + } + m.muAccept.Unlock() + } + + stream.mu.Lock() + switch stream.state { + case streamStateClosed: + fallthrough + case streamStateListen: + stream.setState(streamStateSynRecv) + default: + log.Printf("[ERR] Syn received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketAck: + stream.mu.Lock() + switch stream.state { + case streamStateSynRecv: + stream.setState(streamStateEstablished) + case streamStateFinWait1: + stream.setState(streamStateFinWait2) + case streamStateLastAck: + stream.closeWriter() + fallthrough + case streamStateClosing: + stream.setState(streamStateClosed) + default: + log.Printf("[ERR] Ack received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketSynAck: + stream.mu.Lock() + switch stream.state { + case streamStateSynSent: + stream.setState(streamStateEstablished) + default: + log.Printf("[ERR] SynAck received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketFin: + stream.mu.Lock() + switch stream.state { + case streamStateEstablished: + stream.closeWriter() + stream.setState(streamStateCloseWait) + stream.write(muxPacketAck, nil) + case streamStateFinWait2: + stream.closeWriter() + stream.setState(streamStateClosed) + stream.write(muxPacketAck, nil) + case streamStateFinWait1: + stream.closeWriter() + stream.setState(streamStateClosing) + stream.write(muxPacketAck, nil) + default: + log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) + } + stream.mu.Unlock() + + case muxPacketData: + stream.mu.Lock() + switch stream.state { + case streamStateFinWait1: + fallthrough + case streamStateFinWait2: + fallthrough + case streamStateEstablished: + if len(data) > 0 { + select { + case stream.writeCh <- data: + default: + panic(fmt.Sprintf( + "Failed to write data, buffer full for stream %d", id)) + } + } + default: + log.Printf("[ERR] Data received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + } + } } - //log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType) - switch packetType { - case muxPacketSyn: - // If the stream is nil, this is the only case where we'll - // automatically create the stream struct. - if stream == nil { - var ok bool + func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) { + m.wlock.Lock() + defer m.wlock.Unlock() - m.muAccept.Lock() - stream, ok = m.streamsAccept[id] - if !ok { - stream = newStream(muxPacketFromAccept, id, m) - m.streamsAccept[id] = stream - } - m.muAccept.Unlock() + if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil { + return 0, err } + if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { + return 0, err + } + if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil { + return 0, err + } + if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + return m.rwc.Write(p) + } - stream.mu.Lock() - switch stream.state { - case streamStateClosed: - fallthrough - case streamStateListen: - stream.setState(streamStateSynRecv) - default: - log.Printf("[ERR] Syn received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - case muxPacketAck: - stream.mu.Lock() - switch stream.state { - case streamStateSynRecv: - stream.setState(streamStateEstablished) - case streamStateFinWait1: - stream.setState(streamStateFinWait2) - case streamStateLastAck: - stream.closeWriter() - fallthrough - case streamStateClosing: - stream.setState(streamStateClosed) - default: - log.Printf("[ERR] Ack received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - case muxPacketSynAck: - stream.mu.Lock() - switch stream.state { - case streamStateSynSent: - stream.setState(streamStateEstablished) - default: - log.Printf("[ERR] SynAck received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - case muxPacketFin: - stream.mu.Lock() - switch stream.state { - case streamStateEstablished: - stream.closeWriter() - stream.setState(streamStateCloseWait) - stream.write(muxPacketAck, nil) - case streamStateFinWait2: - stream.closeWriter() - stream.setState(streamStateClosed) - stream.write(muxPacketAck, nil) - case streamStateFinWait1: - stream.closeWriter() - stream.setState(streamStateClosing) - stream.write(muxPacketAck, nil) - default: - log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) - } - stream.mu.Unlock() + // Stream is a single stream of data and implements io.ReadWriteCloser. + // A Stream is full-duplex so you can write data as well as read data. + type Stream struct { + from muxPacketFrom + id uint32 + mux *MuxConn + reader io.Reader + state streamState + stateChange map[chan<- streamState]struct{} + stateUpdated time.Time + mu sync.Mutex + writeCh chan<- []byte + } - case muxPacketData: - stream.mu.Lock() - switch stream.state { - case streamStateFinWait1: - fallthrough - case streamStateFinWait2: - fallthrough - case streamStateEstablished: - if len(data) > 0 { - select { - case stream.writeCh <- data: - default: - panic(fmt.Sprintf( - "Failed to write data, buffer full for stream %d", id)) + type streamState byte + + const ( + streamStateClosed streamState = iota + streamStateListen + streamStateSynRecv + streamStateSynSent + streamStateEstablished + streamStateFinWait1 + streamStateFinWait2 + streamStateCloseWait + streamStateClosing + streamStateLastAck + ) + + func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { + // Create the stream object and channel where data will be sent to + dataR, dataW := io.Pipe() + writeCh := make(chan []byte, 4096) + + // Set the data channel so we can write to it. + stream := &Stream{ + from: from, + id: id, + mux: m, + reader: dataR, + writeCh: writeCh, + stateChange: make(map[chan<- streamState]struct{}), + } + stream.setState(streamStateClosed) + + // Start the goroutine that will read from the queue and write + // data out. + go func() { + defer dataW.Close() + + for { + data := <-writeCh + if data == nil { + // A nil is a tombstone letting us know we're done + // accepting data. + return + } + + if _, err := dataW.Write(data); err != nil { + return } } - default: - log.Printf("[ERR] Data received for stream in state: %d", stream.state) - } - stream.mu.Unlock() + }() + + return stream } - } -} -func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) { - m.wlock.Lock() - defer m.wlock.Unlock() + func (s *Stream) Close() error { + s.mu.Lock() + defer s.mu.Unlock() - if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil { - return 0, err - } - if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { - return 0, err - } - if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil { - return 0, err - } - if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { - return 0, err - } - if len(p) == 0 { - return 0, nil - } - return m.rwc.Write(p) -} - -// Stream is a single stream of data and implements io.ReadWriteCloser. -// A Stream is full-duplex so you can write data as well as read data. -type Stream struct { - from muxPacketFrom - id uint32 - mux *MuxConn - reader io.Reader - state streamState - stateChange map[chan<- streamState]struct{} - stateUpdated time.Time - mu sync.Mutex - writeCh chan<- []byte -} - -type streamState byte - -const ( - streamStateClosed streamState = iota - streamStateListen - streamStateSynRecv - streamStateSynSent - streamStateEstablished - streamStateFinWait1 - streamStateFinWait2 - streamStateCloseWait - streamStateClosing - streamStateLastAck -) - -func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { - // Create the stream object and channel where data will be sent to - dataR, dataW := io.Pipe() - writeCh := make(chan []byte, 4096) - - // Set the data channel so we can write to it. - stream := &Stream{ - from: from, - id: id, - mux: m, - reader: dataR, - writeCh: writeCh, - stateChange: make(map[chan<- streamState]struct{}), - } - stream.setState(streamStateClosed) - - // Start the goroutine that will read from the queue and write - // data out. - go func() { - defer dataW.Close() - - for { - data := <-writeCh - if data == nil { - // A nil is a tombstone letting us know we're done - // accepting data. - return + if s.state != streamStateEstablished && s.state != streamStateCloseWait { + return fmt.Errorf("Stream in bad state: %d", s.state) } - if _, err := dataW.Write(data); err != nil { - return + if s.state == streamStateEstablished { + s.setState(streamStateFinWait1) + } else { + s.setState(streamStateLastAck) + } + + s.write(muxPacketFin, nil) + return nil + } + + func (s *Stream) Read(p []byte) (int, error) { + return s.reader.Read(p) + } + + func (s *Stream) Write(p []byte) (int, error) { + s.mu.Lock() + state := s.state + s.mu.Unlock() + + if state != streamStateEstablished && state != streamStateCloseWait { + return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) + } + + return s.write(muxPacketData, p) + } + + func (s *Stream) closeWriter() { + s.writeCh <- nil + } + + func (s *Stream) setState(state streamState) { + //log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state) + s.state = state + s.stateUpdated = time.Now().UTC() + for ch, _ := range s.stateChange { + select { + case ch <- state: + default: + } } } - }() - return stream -} + func (s *Stream) waitState(target streamState) error { + // Register a state change listener to wait for changes + stateCh := make(chan streamState, 10) + s.stateChange[stateCh] = struct{}{} + s.mu.Unlock() -func (s *Stream) Close() error { - s.mu.Lock() - defer s.mu.Unlock() + defer func() { + s.mu.Lock() + delete(s.stateChange, stateCh) + }() - if s.state != streamStateEstablished && s.state != streamStateCloseWait { - return fmt.Errorf("Stream in bad state: %d", s.state) - } - - if s.state == streamStateEstablished { - s.setState(streamStateFinWait1) - } else { - s.setState(streamStateLastAck) - } - - s.write(muxPacketFin, nil) - return nil -} - -func (s *Stream) Read(p []byte) (int, error) { - return s.reader.Read(p) -} - -func (s *Stream) Write(p []byte) (int, error) { - s.mu.Lock() - state := s.state - s.mu.Unlock() - - if state != streamStateEstablished && state != streamStateCloseWait { - return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) - } - - return s.write(muxPacketData, p) -} - -func (s *Stream) closeWriter() { - s.writeCh <- nil -} - -func (s *Stream) setState(state streamState) { - //log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state) - s.state = state - s.stateUpdated = time.Now().UTC() - for ch, _ := range s.stateChange { - select { - case ch <- state: - default: + state := <-stateCh + if state == target { + return nil + } else { + return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) + } } - } -} -func (s *Stream) waitState(target streamState) error { - // Register a state change listener to wait for changes - stateCh := make(chan streamState, 10) - s.stateChange[stateCh] = struct{}{} - s.mu.Unlock() - - defer func() { - s.mu.Lock() - delete(s.stateChange, stateCh) - }() - - state := <-stateCh - if state == target { - return nil - } else { - return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) - } -} - -func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) { - return s.mux.write(s.from, s.id, dataType, p) -} + func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) { + return s.mux.write(s.from, s.id, dataType, p) + }