From 5c6831080cd906759b1d348a0f8936d5230826a6 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sun, 8 Dec 2013 18:30:29 -0800 Subject: [PATCH] packer/rpc: close the streams when the underlying rwc closes --- packer/rpc/muxconn.go | 25 ++++++++-- packer/rpc/muxconn_test.go | 100 +++++++++++++++++++++++++++++-------- 2 files changed, 101 insertions(+), 24 deletions(-) diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index b1f1df2e5..91b9b08d0 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -13,14 +13,14 @@ import ( // to actually act as a server as well. // // MuxConn works using a fairly dumb multiplexing technique of simply -// prefixing each message with whether it is on stream 0 (the original) -// or stream 1 (the client "server"). +// prefixing each message with what stream it is on along with the length +// of the data. // // This can likely be abstracted to N streams, but by choosing only two // we decided to cut a lot of corners and make this easily usable for Packer. type MuxConn struct { rwc io.ReadWriteCloser - streams map[byte]io.Writer + streams map[byte]io.WriteCloser mu sync.RWMutex wlock sync.Mutex } @@ -28,7 +28,7 @@ type MuxConn struct { func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { m := &MuxConn{ rwc: rwc, - streams: make(map[byte]io.Writer), + streams: make(map[byte]io.WriteCloser), } go m.loop() @@ -36,6 +36,21 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { return m } +// Close closes the underlying io.ReadWriteCloser. This will also close +// all streams that are open. +func (m *MuxConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Close all the streams + for _, w := range m.streams { + w.Close() + } + m.streams = make(map[byte]io.WriteCloser) + + return m.rwc.Close() +} + // Stream returns a io.ReadWriteCloser that will only read/write to the // given stream ID. No handshake is done so if the remote end does not // have a stream open with the same ID, then the messages will simply @@ -67,6 +82,8 @@ func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) { } func (m *MuxConn) loop() { + defer m.Close() + for { var id byte var length int32 diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go index 4493e1c5f..9c7b69eb7 100644 --- a/packer/rpc/muxconn_test.go +++ b/packer/rpc/muxconn_test.go @@ -17,12 +17,43 @@ func readStream(t *testing.T, s io.Reader) string { return string(data[0:n]) } -func TestMuxConn(t *testing.T) { +func testMux(t *testing.T) (client *MuxConn, server *MuxConn) { l, err := net.Listen("tcp", ":0") if err != nil { t.Fatalf("err: %s", err) } + // Server side + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + conn, err := l.Accept() + l.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + server = NewMuxConn(conn) + }() + + // Client side + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + client = NewMuxConn(conn) + + // Wait for the server + <-doneCh + + return +} + +func TestMuxConn(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + // When the server is done doneCh := make(chan struct{}) readyCh := make(chan struct{}) @@ -30,20 +61,13 @@ func TestMuxConn(t *testing.T) { // The server side go func() { defer close(doneCh) - conn, err := l.Accept() - l.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - defer conn.Close() - mux := NewMuxConn(conn) - s0, err := mux.Stream(0) + s0, err := server.Stream(0) if err != nil { t.Fatalf("err: %s", err) } - s1, err := mux.Stream(1) + s1, err := server.Stream(1) if err != nil { t.Fatalf("err: %s", err) } @@ -72,20 +96,12 @@ func TestMuxConn(t *testing.T) { wg.Wait() }() - // Client side - conn, err := net.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("err: %s", err) - } - defer conn.Close() - - mux := NewMuxConn(conn) - s0, err := mux.Stream(0) + s0, err := client.Stream(0) if err != nil { t.Fatalf("err: %s", err) } - s1, err := mux.Stream(1) + s1, err := client.Stream(1) if err != nil { t.Fatalf("err: %s", err) } @@ -103,3 +119,47 @@ func TestMuxConn(t *testing.T) { // Wait for the server to be done <-doneCh } + +func TestMuxConn_clientClosesStreams(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + + s0, err := client.Stream(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := client.Close(); err != nil { + t.Fatalf("err: %s", err) + } + + // This should block forever since we never write onto this stream. + var data [1024]byte + _, err = s0.Read(data[:]) + if err != io.EOF { + t.Fatalf("err: %s", err) + } +} + +func TestMuxConn_serverClosesStreams(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + + s0, err := client.Stream(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := server.Close(); err != nil { + t.Fatalf("err: %s", err) + } + + // This should block forever since we never write onto this stream. + var data [1024]byte + _, err = s0.Read(data[:]) + if err != io.EOF { + t.Fatalf("err: %s", err) + } +}