packer/rpc: better close states
This commit is contained in:
parent
0a6061fd0b
commit
bcebec8fc3
@ -21,7 +21,7 @@ type MuxConn struct {
|
||||
curId uint32
|
||||
rwc io.ReadWriteCloser
|
||||
streams map[uint32]*Stream
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
wlock sync.Mutex
|
||||
}
|
||||
|
||||
@ -48,8 +48,8 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||
// Close closes the underlying io.ReadWriteCloser. This will also close
|
||||
// all streams that are open.
|
||||
func (m *MuxConn) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Close all the streams
|
||||
for _, w := range m.streams {
|
||||
@ -94,12 +94,17 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
||||
switch stream.state {
|
||||
case streamStateListen:
|
||||
stream.mu.Unlock()
|
||||
case streamStateClosed:
|
||||
// This can happen if it becomes established, some data is sent,
|
||||
// and it closed all within the time period we wait above.
|
||||
// This case will be fixed when we have edge-triggered checks.
|
||||
fallthrough
|
||||
case streamStateEstablished:
|
||||
stream.mu.Unlock()
|
||||
break ACCEPT_ESTABLISH_LOOP
|
||||
default:
|
||||
defer stream.mu.Unlock()
|
||||
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
|
||||
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -140,12 +145,17 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
||||
switch stream.state {
|
||||
case streamStateSynSent:
|
||||
stream.mu.Unlock()
|
||||
case streamStateClosed:
|
||||
// This can happen if it becomes established, some data is sent,
|
||||
// and it closed all within the time period we wait above.
|
||||
// This case will be fixed when we have edge-triggered checks.
|
||||
fallthrough
|
||||
case streamStateEstablished:
|
||||
stream.mu.Unlock()
|
||||
return stream, nil
|
||||
default:
|
||||
defer stream.mu.Unlock()
|
||||
return nil, fmt.Errorf("Stream went to bad state: %d", stream.state)
|
||||
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -166,9 +176,21 @@ func (m *MuxConn) NextId() uint32 {
|
||||
}
|
||||
|
||||
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
||||
// First grab a read-lock if we have the stream already we can
|
||||
// cheaply return it.
|
||||
m.mu.RLock()
|
||||
if stream, ok := m.streams[id]; ok {
|
||||
m.mu.RUnlock()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Now acquire a full blown write lock so we can create the stream
|
||||
m.mu.RUnlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// We have to check this again because there is a time period
|
||||
// above where we couldn't lost this lock.
|
||||
if stream, ok := m.streams[id]; ok {
|
||||
return stream, nil
|
||||
}
|
||||
@ -182,7 +204,6 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
||||
id: id,
|
||||
mux: m,
|
||||
reader: dataR,
|
||||
writer: dataW,
|
||||
writeCh: writeCh,
|
||||
}
|
||||
stream.setState(streamStateClosed)
|
||||
@ -190,8 +211,16 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
||||
// Start the goroutine that will read from the queue and write
|
||||
// data out.
|
||||
go func() {
|
||||
defer dataW.Close()
|
||||
|
||||
for {
|
||||
data := <-writeCh
|
||||
if data == nil {
|
||||
// A nil is a tombstone letting us know we're done
|
||||
// accepting data.
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := dataW.Write(data); err != nil {
|
||||
return
|
||||
}
|
||||
@ -237,12 +266,16 @@ func (m *MuxConn) loop() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
|
||||
switch packetType {
|
||||
case muxPacketAck:
|
||||
stream.mu.Lock()
|
||||
if stream.state == streamStateSynSent {
|
||||
switch stream.state {
|
||||
case streamStateSynSent:
|
||||
stream.setState(streamStateEstablished)
|
||||
} else {
|
||||
case streamStateFinWait1:
|
||||
stream.remoteClose()
|
||||
default:
|
||||
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
@ -259,13 +292,23 @@ func (m *MuxConn) loop() {
|
||||
stream.mu.Unlock()
|
||||
case muxPacketFin:
|
||||
stream.mu.Lock()
|
||||
stream.setState(streamStateClosed)
|
||||
stream.writer.Close()
|
||||
switch stream.state {
|
||||
case streamStateEstablished:
|
||||
m.write(id, muxPacketAck, nil)
|
||||
fallthrough
|
||||
case streamStateFinWait1:
|
||||
stream.remoteClose()
|
||||
|
||||
// Remove this stream from being active so that it
|
||||
// can be re-used
|
||||
m.mu.Lock()
|
||||
delete(m.streams, stream.id)
|
||||
m.mu.Unlock()
|
||||
default:
|
||||
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.streams, stream.id)
|
||||
m.mu.Unlock()
|
||||
case muxPacketData:
|
||||
stream.mu.Lock()
|
||||
if stream.state == streamStateEstablished {
|
||||
@ -306,7 +349,6 @@ type Stream struct {
|
||||
id uint32
|
||||
mux *MuxConn
|
||||
reader io.Reader
|
||||
writer io.WriteCloser
|
||||
state streamState
|
||||
stateUpdated time.Time
|
||||
mu sync.Mutex
|
||||
@ -321,23 +363,37 @@ const (
|
||||
streamStateSynRecv
|
||||
streamStateSynSent
|
||||
streamStateEstablished
|
||||
streamStateFinWait
|
||||
streamStateFinWait1
|
||||
)
|
||||
|
||||
func (s *Stream) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.state != streamStateEstablished {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("Stream in bad state: %d", s.state)
|
||||
}
|
||||
|
||||
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
s.setState(streamStateFinWait1)
|
||||
s.mu.Unlock()
|
||||
|
||||
for {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
s.mu.Lock()
|
||||
switch s.state {
|
||||
case streamStateFinWait1:
|
||||
s.mu.Unlock()
|
||||
case streamStateClosed:
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
default:
|
||||
defer s.mu.Unlock()
|
||||
return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state)
|
||||
}
|
||||
}
|
||||
|
||||
s.setState(streamStateClosed)
|
||||
s.writer.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -346,9 +402,22 @@ func (s *Stream) Read(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (s *Stream) Write(p []byte) (int, error) {
|
||||
s.mu.Lock()
|
||||
state := s.state
|
||||
s.mu.Unlock()
|
||||
|
||||
if state != streamStateEstablished {
|
||||
return 0, fmt.Errorf("Stream in bad state to send: %d", state)
|
||||
}
|
||||
|
||||
return s.mux.write(s.id, muxPacketData, p)
|
||||
}
|
||||
|
||||
func (s *Stream) remoteClose() {
|
||||
s.setState(streamStateClosed)
|
||||
s.writeCh <- nil
|
||||
}
|
||||
|
||||
func (s *Stream) setState(state streamState) {
|
||||
s.state = state
|
||||
s.stateUpdated = time.Now().UTC()
|
||||
|
Loading…
x
Reference in New Issue
Block a user