packer/rpc: implement proper close_wait state
This commit is contained in:
parent
e4dbad330d
commit
7372c32b6b
|
@ -150,6 +150,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
||||||
// and it closed all within the time period we wait above.
|
// and it closed all within the time period we wait above.
|
||||||
// This case will be fixed when we have edge-triggered checks.
|
// This case will be fixed when we have edge-triggered checks.
|
||||||
fallthrough
|
fallthrough
|
||||||
|
case streamStateCloseWait:
|
||||||
|
fallthrough
|
||||||
case streamStateEstablished:
|
case streamStateEstablished:
|
||||||
stream.mu.Unlock()
|
stream.mu.Unlock()
|
||||||
return stream, nil
|
return stream, nil
|
||||||
|
@ -274,7 +276,7 @@ func (m *MuxConn) loop() {
|
||||||
case streamStateSynSent:
|
case streamStateSynSent:
|
||||||
stream.setState(streamStateEstablished)
|
stream.setState(streamStateEstablished)
|
||||||
case streamStateFinWait1:
|
case streamStateFinWait1:
|
||||||
stream.remoteClose()
|
stream.setState(streamStateFinWait2)
|
||||||
default:
|
default:
|
||||||
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
||||||
}
|
}
|
||||||
|
@ -294,9 +296,15 @@ func (m *MuxConn) loop() {
|
||||||
stream.mu.Lock()
|
stream.mu.Lock()
|
||||||
switch stream.state {
|
switch stream.state {
|
||||||
case streamStateEstablished:
|
case streamStateEstablished:
|
||||||
|
stream.setState(streamStateCloseWait)
|
||||||
m.write(id, muxPacketAck, nil)
|
m.write(id, muxPacketAck, nil)
|
||||||
fallthrough
|
|
||||||
|
// Close the writer on our end since we won't receive any
|
||||||
|
// more data.
|
||||||
|
stream.writeCh <- nil
|
||||||
case streamStateFinWait1:
|
case streamStateFinWait1:
|
||||||
|
fallthrough
|
||||||
|
case streamStateFinWait2:
|
||||||
stream.remoteClose()
|
stream.remoteClose()
|
||||||
|
|
||||||
// Remove this stream from being active so that it
|
// Remove this stream from being active so that it
|
||||||
|
@ -364,34 +372,26 @@ const (
|
||||||
streamStateSynSent
|
streamStateSynSent
|
||||||
streamStateEstablished
|
streamStateEstablished
|
||||||
streamStateFinWait1
|
streamStateFinWait1
|
||||||
|
streamStateFinWait2
|
||||||
|
streamStateCloseWait
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Stream) Close() error {
|
func (s *Stream) Close() error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.state != streamStateEstablished {
|
defer s.mu.Unlock()
|
||||||
s.mu.Unlock()
|
|
||||||
|
if s.state != streamStateEstablished && s.state != streamStateCloseWait {
|
||||||
return fmt.Errorf("Stream in bad state: %d", s.state)
|
return fmt.Errorf("Stream in bad state: %d", s.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
|
if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.setState(streamStateFinWait1)
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
for {
|
if s.state == streamStateEstablished {
|
||||||
time.Sleep(50 * time.Millisecond)
|
s.setState(streamStateFinWait1)
|
||||||
s.mu.Lock()
|
} else {
|
||||||
switch s.state {
|
s.remoteClose()
|
||||||
case streamStateFinWait1:
|
|
||||||
s.mu.Unlock()
|
|
||||||
case streamStateClosed:
|
|
||||||
s.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -118,18 +118,20 @@ func TestMuxConn_clientClosesStreams(t *testing.T) {
|
||||||
client, server := testMux(t)
|
client, server := testMux(t)
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
go server.Accept(0)
|
|
||||||
|
go func() {
|
||||||
|
conn, err := server.Accept(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
s0, err := client.Dial(0)
|
s0, err := client.Dial(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Close(); err != nil {
|
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This should block forever since we never write onto this stream.
|
|
||||||
var data [1024]byte
|
var data [1024]byte
|
||||||
_, err = s0.Read(data[:])
|
_, err = s0.Read(data[:])
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
|
|
Loading…
Reference in New Issue