packer/rpc: close the streams when the underlying rwc closes
This commit is contained in:
parent
fe46093bcf
commit
5c6831080c
|
@ -13,14 +13,14 @@ import (
|
||||||
// to actually act as a server as well.
|
// to actually act as a server as well.
|
||||||
//
|
//
|
||||||
// MuxConn works using a fairly dumb multiplexing technique of simply
|
// MuxConn works using a fairly dumb multiplexing technique of simply
|
||||||
// prefixing each message with whether it is on stream 0 (the original)
|
// prefixing each message with what stream it is on along with the length
|
||||||
// or stream 1 (the client "server").
|
// of the data.
|
||||||
//
|
//
|
||||||
// This can likely be abstracted to N streams, but by choosing only two
|
// This can likely be abstracted to N streams, but by choosing only two
|
||||||
// we decided to cut a lot of corners and make this easily usable for Packer.
|
// we decided to cut a lot of corners and make this easily usable for Packer.
|
||||||
type MuxConn struct {
|
type MuxConn struct {
|
||||||
rwc io.ReadWriteCloser
|
rwc io.ReadWriteCloser
|
||||||
streams map[byte]io.Writer
|
streams map[byte]io.WriteCloser
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
wlock sync.Mutex
|
wlock sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ type MuxConn struct {
|
||||||
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||||
m := &MuxConn{
|
m := &MuxConn{
|
||||||
rwc: rwc,
|
rwc: rwc,
|
||||||
streams: make(map[byte]io.Writer),
|
streams: make(map[byte]io.WriteCloser),
|
||||||
}
|
}
|
||||||
|
|
||||||
go m.loop()
|
go m.loop()
|
||||||
|
@ -36,6 +36,21 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying io.ReadWriteCloser. This will also close
|
||||||
|
// all streams that are open.
|
||||||
|
func (m *MuxConn) Close() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Close all the streams
|
||||||
|
for _, w := range m.streams {
|
||||||
|
w.Close()
|
||||||
|
}
|
||||||
|
m.streams = make(map[byte]io.WriteCloser)
|
||||||
|
|
||||||
|
return m.rwc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// Stream returns a io.ReadWriteCloser that will only read/write to the
|
// Stream returns a io.ReadWriteCloser that will only read/write to the
|
||||||
// given stream ID. No handshake is done so if the remote end does not
|
// given stream ID. No handshake is done so if the remote end does not
|
||||||
// have a stream open with the same ID, then the messages will simply
|
// have a stream open with the same ID, then the messages will simply
|
||||||
|
@ -67,6 +82,8 @@ func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MuxConn) loop() {
|
func (m *MuxConn) loop() {
|
||||||
|
defer m.Close()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var id byte
|
var id byte
|
||||||
var length int32
|
var length int32
|
||||||
|
|
|
@ -17,12 +17,43 @@ func readStream(t *testing.T, s io.Reader) string {
|
||||||
return string(data[0:n])
|
return string(data[0:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMuxConn(t *testing.T) {
|
func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
|
||||||
l, err := net.Listen("tcp", ":0")
|
l, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server side
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(doneCh)
|
||||||
|
conn, err := l.Accept()
|
||||||
|
l.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server = NewMuxConn(conn)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Client side
|
||||||
|
conn, err := net.Dial("tcp", l.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
client = NewMuxConn(conn)
|
||||||
|
|
||||||
|
// Wait for the server
|
||||||
|
<-doneCh
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMuxConn(t *testing.T) {
|
||||||
|
client, server := testMux(t)
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
// When the server is done
|
// When the server is done
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
readyCh := make(chan struct{})
|
readyCh := make(chan struct{})
|
||||||
|
@ -30,20 +61,13 @@ func TestMuxConn(t *testing.T) {
|
||||||
// The server side
|
// The server side
|
||||||
go func() {
|
go func() {
|
||||||
defer close(doneCh)
|
defer close(doneCh)
|
||||||
conn, err := l.Accept()
|
|
||||||
l.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
mux := NewMuxConn(conn)
|
s0, err := server.Stream(0)
|
||||||
s0, err := mux.Stream(0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s1, err := mux.Stream(1)
|
s1, err := server.Stream(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -72,20 +96,12 @@ func TestMuxConn(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Client side
|
s0, err := client.Stream(0)
|
||||||
conn, err := net.Dial("tcp", l.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
mux := NewMuxConn(conn)
|
|
||||||
s0, err := mux.Stream(0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s1, err := mux.Stream(1)
|
s1, err := client.Stream(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -103,3 +119,47 @@ func TestMuxConn(t *testing.T) {
|
||||||
// Wait for the server to be done
|
// Wait for the server to be done
|
||||||
<-doneCh
|
<-doneCh
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMuxConn_clientClosesStreams(t *testing.T) {
|
||||||
|
client, server := testMux(t)
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
s0, err := client.Stream(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Close(); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should block forever since we never write onto this stream.
|
||||||
|
var data [1024]byte
|
||||||
|
_, err = s0.Read(data[:])
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMuxConn_serverClosesStreams(t *testing.T) {
|
||||||
|
client, server := testMux(t)
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
s0, err := client.Stream(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should block forever since we never write onto this stream.
|
||||||
|
var data [1024]byte
|
||||||
|
_, err = s0.Read(data[:])
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue