packer/rpc: Clean up old streams [GH-708]
This commit is contained in:
parent
b1f07dcbe0
commit
bec978fd8b
|
@ -27,6 +27,7 @@ type MuxConn struct {
|
|||
streams map[uint32]*Stream
|
||||
mu sync.RWMutex
|
||||
wlock sync.Mutex
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
type muxPacketType byte
|
||||
|
@ -44,8 +45,10 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
|||
m := &MuxConn{
|
||||
rwc: rwc,
|
||||
streams: make(map[uint32]*Stream),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go m.cleaner()
|
||||
go m.loop()
|
||||
|
||||
return m
|
||||
|
@ -211,18 +214,45 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
|||
return m.streams[id], nil
|
||||
}
|
||||
|
||||
func (m *MuxConn) cleaner() {
|
||||
for {
|
||||
done := false
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-m.doneCh:
|
||||
done = true
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
for id, s := range m.streams {
|
||||
s.mu.Lock()
|
||||
if s.state == streamStateClosed {
|
||||
delete(m.streams, id)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
if done {
|
||||
for _, s := range m.streams {
|
||||
s.mu.Lock()
|
||||
s.closeWriter()
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if done {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MuxConn) loop() {
|
||||
// Force close every stream that we know about when we exit so
|
||||
// that they all read EOF and don't block forever.
|
||||
defer func() {
|
||||
log.Printf("[INFO] Mux connection loop exiting")
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, w := range m.streams {
|
||||
w.mu.Lock()
|
||||
w.remoteClose()
|
||||
w.mu.Unlock()
|
||||
}
|
||||
close(m.doneCh)
|
||||
}()
|
||||
|
||||
var id uint32
|
||||
|
@ -277,6 +307,11 @@ func (m *MuxConn) loop() {
|
|||
stream.setState(streamStateEstablished)
|
||||
case streamStateFinWait1:
|
||||
stream.setState(streamStateFinWait2)
|
||||
case streamStateLastAck:
|
||||
stream.closeWriter()
|
||||
fallthrough
|
||||
case streamStateClosing:
|
||||
stream.setState(streamStateClosed)
|
||||
default:
|
||||
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
||||
}
|
||||
|
@ -294,20 +329,17 @@ func (m *MuxConn) loop() {
|
|||
stream.mu.Lock()
|
||||
switch stream.state {
|
||||
case streamStateEstablished:
|
||||
stream.closeWriter()
|
||||
stream.setState(streamStateCloseWait)
|
||||
m.write(id, muxPacketAck, nil)
|
||||
|
||||
// Close the writer on our end since we won't receive any
|
||||
// more data.
|
||||
stream.writeCh <- nil
|
||||
case streamStateFinWait1:
|
||||
fallthrough
|
||||
case streamStateFinWait2:
|
||||
stream.remoteClose()
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.streams, stream.id)
|
||||
m.mu.Unlock()
|
||||
stream.closeWriter()
|
||||
stream.setState(streamStateClosed)
|
||||
m.write(id, muxPacketAck, nil)
|
||||
case streamStateFinWait1:
|
||||
stream.closeWriter()
|
||||
stream.setState(streamStateClosing)
|
||||
m.write(id, muxPacketAck, nil)
|
||||
default:
|
||||
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
||||
}
|
||||
|
@ -377,6 +409,8 @@ const (
|
|||
streamStateFinWait1
|
||||
streamStateFinWait2
|
||||
streamStateCloseWait
|
||||
streamStateClosing
|
||||
streamStateLastAck
|
||||
)
|
||||
|
||||
func (s *Stream) Close() error {
|
||||
|
@ -390,7 +424,7 @@ func (s *Stream) Close() error {
|
|||
if s.state == streamStateEstablished {
|
||||
s.setState(streamStateFinWait1)
|
||||
} else {
|
||||
s.remoteClose()
|
||||
s.setState(streamStateLastAck)
|
||||
}
|
||||
|
||||
s.mux.write(s.id, muxPacketFin, nil)
|
||||
|
@ -413,8 +447,7 @@ func (s *Stream) Write(p []byte) (int, error) {
|
|||
return s.mux.write(s.id, muxPacketData, p)
|
||||
}
|
||||
|
||||
func (s *Stream) remoteClose() {
|
||||
s.setState(streamStateClosed)
|
||||
func (s *Stream) closeWriter() {
|
||||
s.writeCh <- nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue