packer/rpc: muxconn can't use stream ID 0 ever

This commit is contained in:
Mitchell Hashimoto 2013-12-30 21:03:10 -08:00
parent 7177cc149f
commit ae37050e8a
2 changed files with 248 additions and 246 deletions

View File

@ -1,6 +1,8 @@
## 0.5.1 (unreleased)
BUG FIXES:
* core: If a stream ID loops around, don't let it use stream ID 0 [GH-767]
## 0.5.0 (12/30/2013)

View File

@ -191,13 +191,13 @@ func (m *MuxConn) NextId() uint32 {
m.muAccept.Lock()
defer m.muAccept.Unlock()
// We never use stream ID 0 because 0 is the zero value of a uint32
// and we want to reserve that for "not in use"
if m.curId == 0 {
m.curId = 1
}
for {
// We never use stream ID 0 because 0 is the zero value of a uint32
// and we want to reserve that for "not in use"
if m.curId == 0 {
m.curId = 1
}
result := m.curId
m.curId += 1
if _, ok := m.streamsAccept[result]; !ok {
@ -319,263 +319,263 @@ func (m *MuxConn) loop() {
log.Printf(
"[WARN] %p: Non-existent stream %d (%s) received packer %d",
m, id, from, packetType)
continue
continue
}
//log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType)
switch packetType {
case muxPacketSyn:
// If the stream is nil, this is the only case where we'll
// automatically create the stream struct.
if stream == nil {
var ok bool
m.muAccept.Lock()
stream, ok = m.streamsAccept[id]
if !ok {
stream = newStream(muxPacketFromAccept, id, m)
m.streamsAccept[id] = stream
}
m.muAccept.Unlock()
}
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
fallthrough
case streamStateListen:
stream.setState(streamStateSynRecv)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynRecv:
stream.setState(streamStateEstablished)
case streamStateFinWait1:
stream.setState(streamStateFinWait2)
case streamStateLastAck:
stream.closeWriter()
fallthrough
case streamStateClosing:
stream.setState(streamStateClosed)
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSynAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
switch stream.state {
case streamStateEstablished:
stream.closeWriter()
stream.setState(streamStateCloseWait)
stream.write(muxPacketAck, nil)
case streamStateFinWait2:
stream.closeWriter()
stream.setState(streamStateClosed)
stream.write(muxPacketAck, nil)
case streamStateFinWait1:
stream.closeWriter()
stream.setState(streamStateClosing)
stream.write(muxPacketAck, nil)
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
stream.mu.Unlock()
case muxPacketData:
stream.mu.Lock()
switch stream.state {
case streamStateFinWait1:
fallthrough
case streamStateFinWait2:
fallthrough
case streamStateEstablished:
if len(data) > 0 {
select {
case stream.writeCh <- data:
default:
panic(fmt.Sprintf(
"Failed to write data, buffer full for stream %d", id))
}
}
default:
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
}
}
}
//log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType)
switch packetType {
case muxPacketSyn:
// If the stream is nil, this is the only case where we'll
// automatically create the stream struct.
if stream == nil {
var ok bool
func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock()
defer m.wlock.Unlock()
m.muAccept.Lock()
stream, ok = m.streamsAccept[id]
if !ok {
stream = newStream(muxPacketFromAccept, id, m)
m.streamsAccept[id] = stream
}
m.muAccept.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
return m.rwc.Write(p)
}
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
fallthrough
case streamStateListen:
stream.setState(streamStateSynRecv)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynRecv:
stream.setState(streamStateEstablished)
case streamStateFinWait1:
stream.setState(streamStateFinWait2)
case streamStateLastAck:
stream.closeWriter()
fallthrough
case streamStateClosing:
stream.setState(streamStateClosed)
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSynAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
switch stream.state {
case streamStateEstablished:
stream.closeWriter()
stream.setState(streamStateCloseWait)
stream.write(muxPacketAck, nil)
case streamStateFinWait2:
stream.closeWriter()
stream.setState(streamStateClosed)
stream.write(muxPacketAck, nil)
case streamStateFinWait1:
stream.closeWriter()
stream.setState(streamStateClosing)
stream.write(muxPacketAck, nil)
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
stream.mu.Unlock()
// Stream is a single stream of data and implements io.ReadWriteCloser.
// A Stream is full-duplex so you can write data as well as read data.
type Stream struct {
from muxPacketFrom
id uint32
mux *MuxConn
reader io.Reader
state streamState
stateChange map[chan<- streamState]struct{}
stateUpdated time.Time
mu sync.Mutex
writeCh chan<- []byte
}
case muxPacketData:
stream.mu.Lock()
switch stream.state {
case streamStateFinWait1:
fallthrough
case streamStateFinWait2:
fallthrough
case streamStateEstablished:
if len(data) > 0 {
select {
case stream.writeCh <- data:
default:
panic(fmt.Sprintf(
"Failed to write data, buffer full for stream %d", id))
type streamState byte
const (
streamStateClosed streamState = iota
streamStateListen
streamStateSynRecv
streamStateSynSent
streamStateEstablished
streamStateFinWait1
streamStateFinWait2
streamStateCloseWait
streamStateClosing
streamStateLastAck
)
func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 4096)
// Set the data channel so we can write to it.
stream := &Stream{
from: from,
id: id,
mux: m,
reader: dataR,
writeCh: writeCh,
stateChange: make(map[chan<- streamState]struct{}),
}
stream.setState(streamStateClosed)
// Start the goroutine that will read from the queue and write
// data out.
go func() {
defer dataW.Close()
for {
data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil {
return
}
}
default:
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
}()
return stream
}
}
}
func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock()
defer m.wlock.Unlock()
func (s *Stream) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
return m.rwc.Write(p)
}
// Stream is a single stream of data and implements io.ReadWriteCloser.
// A Stream is full-duplex so you can write data as well as read data.
type Stream struct {
from muxPacketFrom
id uint32
mux *MuxConn
reader io.Reader
state streamState
stateChange map[chan<- streamState]struct{}
stateUpdated time.Time
mu sync.Mutex
writeCh chan<- []byte
}
type streamState byte
const (
streamStateClosed streamState = iota
streamStateListen
streamStateSynRecv
streamStateSynSent
streamStateEstablished
streamStateFinWait1
streamStateFinWait2
streamStateCloseWait
streamStateClosing
streamStateLastAck
)
func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 4096)
// Set the data channel so we can write to it.
stream := &Stream{
from: from,
id: id,
mux: m,
reader: dataR,
writeCh: writeCh,
stateChange: make(map[chan<- streamState]struct{}),
}
stream.setState(streamStateClosed)
// Start the goroutine that will read from the queue and write
// data out.
go func() {
defer dataW.Close()
for {
data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
if s.state != streamStateEstablished && s.state != streamStateCloseWait {
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if _, err := dataW.Write(data); err != nil {
return
if s.state == streamStateEstablished {
s.setState(streamStateFinWait1)
} else {
s.setState(streamStateLastAck)
}
s.write(muxPacketFin, nil)
return nil
}
func (s *Stream) Read(p []byte) (int, error) {
return s.reader.Read(p)
}
func (s *Stream) Write(p []byte) (int, error) {
s.mu.Lock()
state := s.state
s.mu.Unlock()
if state != streamStateEstablished && state != streamStateCloseWait {
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
}
return s.write(muxPacketData, p)
}
func (s *Stream) closeWriter() {
s.writeCh <- nil
}
func (s *Stream) setState(state streamState) {
//log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state)
s.state = state
s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange {
select {
case ch <- state:
default:
}
}
}
}()
return stream
}
func (s *Stream) waitState(target streamState) error {
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
s.stateChange[stateCh] = struct{}{}
s.mu.Unlock()
func (s *Stream) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.stateChange, stateCh)
}()
if s.state != streamStateEstablished && s.state != streamStateCloseWait {
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if s.state == streamStateEstablished {
s.setState(streamStateFinWait1)
} else {
s.setState(streamStateLastAck)
}
s.write(muxPacketFin, nil)
return nil
}
func (s *Stream) Read(p []byte) (int, error) {
return s.reader.Read(p)
}
func (s *Stream) Write(p []byte) (int, error) {
s.mu.Lock()
state := s.state
s.mu.Unlock()
if state != streamStateEstablished && state != streamStateCloseWait {
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
}
return s.write(muxPacketData, p)
}
func (s *Stream) closeWriter() {
s.writeCh <- nil
}
func (s *Stream) setState(state streamState) {
//log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state)
s.state = state
s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange {
select {
case ch <- state:
default:
state := <-stateCh
if state == target {
return nil
} else {
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
}
}
}
}
func (s *Stream) waitState(target streamState) error {
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
s.stateChange[stateCh] = struct{}{}
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.stateChange, stateCh)
}()
state := <-stateCh
if state == target {
return nil
} else {
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
}
}
func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) {
return s.mux.write(s.from, s.id, dataType, p)
}
func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) {
return s.mux.write(s.from, s.id, dataType, p)
}