packer/rpc: Clean up old streams [GH-708]
This commit is contained in:
parent
b1f07dcbe0
commit
bec978fd8b
|
@ -27,6 +27,7 @@ type MuxConn struct {
|
||||||
streams map[uint32]*Stream
|
streams map[uint32]*Stream
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
wlock sync.Mutex
|
wlock sync.Mutex
|
||||||
|
doneCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxPacketType byte
|
type muxPacketType byte
|
||||||
|
@ -44,8 +45,10 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||||
m := &MuxConn{
|
m := &MuxConn{
|
||||||
rwc: rwc,
|
rwc: rwc,
|
||||||
streams: make(map[uint32]*Stream),
|
streams: make(map[uint32]*Stream),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go m.cleaner()
|
||||||
go m.loop()
|
go m.loop()
|
||||||
|
|
||||||
return m
|
return m
|
||||||
|
@ -211,18 +214,45 @@ func (m *MuxConn) openStream(id uint32) (*Stream, error) {
|
||||||
return m.streams[id], nil
|
return m.streams[id], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MuxConn) cleaner() {
|
||||||
|
for {
|
||||||
|
done := false
|
||||||
|
select {
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
case <-m.doneCh:
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
for id, s := range m.streams {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.state == streamStateClosed {
|
||||||
|
delete(m.streams, id)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if done {
|
||||||
|
for _, s := range m.streams {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.closeWriter()
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MuxConn) loop() {
|
func (m *MuxConn) loop() {
|
||||||
// Force close every stream that we know about when we exit so
|
// Force close every stream that we know about when we exit so
|
||||||
// that they all read EOF and don't block forever.
|
// that they all read EOF and don't block forever.
|
||||||
defer func() {
|
defer func() {
|
||||||
log.Printf("[INFO] Mux connection loop exiting")
|
log.Printf("[INFO] Mux connection loop exiting")
|
||||||
m.mu.Lock()
|
close(m.doneCh)
|
||||||
defer m.mu.Unlock()
|
|
||||||
for _, w := range m.streams {
|
|
||||||
w.mu.Lock()
|
|
||||||
w.remoteClose()
|
|
||||||
w.mu.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var id uint32
|
var id uint32
|
||||||
|
@ -277,6 +307,11 @@ func (m *MuxConn) loop() {
|
||||||
stream.setState(streamStateEstablished)
|
stream.setState(streamStateEstablished)
|
||||||
case streamStateFinWait1:
|
case streamStateFinWait1:
|
||||||
stream.setState(streamStateFinWait2)
|
stream.setState(streamStateFinWait2)
|
||||||
|
case streamStateLastAck:
|
||||||
|
stream.closeWriter()
|
||||||
|
fallthrough
|
||||||
|
case streamStateClosing:
|
||||||
|
stream.setState(streamStateClosed)
|
||||||
default:
|
default:
|
||||||
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
||||||
}
|
}
|
||||||
|
@ -294,20 +329,17 @@ func (m *MuxConn) loop() {
|
||||||
stream.mu.Lock()
|
stream.mu.Lock()
|
||||||
switch stream.state {
|
switch stream.state {
|
||||||
case streamStateEstablished:
|
case streamStateEstablished:
|
||||||
|
stream.closeWriter()
|
||||||
stream.setState(streamStateCloseWait)
|
stream.setState(streamStateCloseWait)
|
||||||
m.write(id, muxPacketAck, nil)
|
m.write(id, muxPacketAck, nil)
|
||||||
|
|
||||||
// Close the writer on our end since we won't receive any
|
|
||||||
// more data.
|
|
||||||
stream.writeCh <- nil
|
|
||||||
case streamStateFinWait1:
|
|
||||||
fallthrough
|
|
||||||
case streamStateFinWait2:
|
case streamStateFinWait2:
|
||||||
stream.remoteClose()
|
stream.closeWriter()
|
||||||
|
stream.setState(streamStateClosed)
|
||||||
m.mu.Lock()
|
m.write(id, muxPacketAck, nil)
|
||||||
delete(m.streams, stream.id)
|
case streamStateFinWait1:
|
||||||
m.mu.Unlock()
|
stream.closeWriter()
|
||||||
|
stream.setState(streamStateClosing)
|
||||||
|
m.write(id, muxPacketAck, nil)
|
||||||
default:
|
default:
|
||||||
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
|
||||||
}
|
}
|
||||||
|
@ -377,6 +409,8 @@ const (
|
||||||
streamStateFinWait1
|
streamStateFinWait1
|
||||||
streamStateFinWait2
|
streamStateFinWait2
|
||||||
streamStateCloseWait
|
streamStateCloseWait
|
||||||
|
streamStateClosing
|
||||||
|
streamStateLastAck
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Stream) Close() error {
|
func (s *Stream) Close() error {
|
||||||
|
@ -390,7 +424,7 @@ func (s *Stream) Close() error {
|
||||||
if s.state == streamStateEstablished {
|
if s.state == streamStateEstablished {
|
||||||
s.setState(streamStateFinWait1)
|
s.setState(streamStateFinWait1)
|
||||||
} else {
|
} else {
|
||||||
s.remoteClose()
|
s.setState(streamStateLastAck)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mux.write(s.id, muxPacketFin, nil)
|
s.mux.write(s.id, muxPacketFin, nil)
|
||||||
|
@ -413,8 +447,7 @@ func (s *Stream) Write(p []byte) (int, error) {
|
||||||
return s.mux.write(s.id, muxPacketData, p)
|
return s.mux.write(s.id, muxPacketData, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) remoteClose() {
|
func (s *Stream) closeWriter() {
|
||||||
s.setState(streamStateClosed)
|
|
||||||
s.writeCh <- nil
|
s.writeCh <- nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue