packer/rpc: close the streams when the underlying rwc closes
This commit is contained in:
parent
fe46093bcf
commit
5c6831080c
|
@ -13,14 +13,14 @@ import (
|
|||
// to actually act as a server as well.
|
||||
//
|
||||
// MuxConn works using a fairly dumb multiplexing technique of simply
|
||||
// prefixing each message with whether it is on stream 0 (the original)
|
||||
// or stream 1 (the client "server").
|
||||
// prefixing each message with what stream it is on along with the length
|
||||
// of the data.
|
||||
//
|
||||
// This can likely be abstracted to N streams, but by choosing only two
|
||||
// we decided to cut a lot of corners and make this easily usable for Packer.
|
||||
type MuxConn struct {
|
||||
rwc io.ReadWriteCloser
|
||||
streams map[byte]io.Writer
|
||||
streams map[byte]io.WriteCloser
|
||||
mu sync.RWMutex
|
||||
wlock sync.Mutex
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ type MuxConn struct {
|
|||
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||
m := &MuxConn{
|
||||
rwc: rwc,
|
||||
streams: make(map[byte]io.Writer),
|
||||
streams: make(map[byte]io.WriteCloser),
|
||||
}
|
||||
|
||||
go m.loop()
|
||||
|
@ -36,6 +36,21 @@ func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
|||
return m
|
||||
}
|
||||
|
||||
// Close closes the underlying io.ReadWriteCloser. This will also close
|
||||
// all streams that are open.
|
||||
func (m *MuxConn) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Close all the streams
|
||||
for _, w := range m.streams {
|
||||
w.Close()
|
||||
}
|
||||
m.streams = make(map[byte]io.WriteCloser)
|
||||
|
||||
return m.rwc.Close()
|
||||
}
|
||||
|
||||
// Stream returns a io.ReadWriteCloser that will only read/write to the
|
||||
// given stream ID. No handshake is done so if the remote end does not
|
||||
// have a stream open with the same ID, then the messages will simply
|
||||
|
@ -67,6 +82,8 @@ func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) {
|
|||
}
|
||||
|
||||
func (m *MuxConn) loop() {
|
||||
defer m.Close()
|
||||
|
||||
for {
|
||||
var id byte
|
||||
var length int32
|
||||
|
|
|
@ -17,12 +17,43 @@ func readStream(t *testing.T, s io.Reader) string {
|
|||
return string(data[0:n])
|
||||
}
|
||||
|
||||
func TestMuxConn(t *testing.T) {
|
||||
func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Server side
|
||||
doneCh := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
conn, err := l.Accept()
|
||||
l.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
server = NewMuxConn(conn)
|
||||
}()
|
||||
|
||||
// Client side
|
||||
conn, err := net.Dial("tcp", l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
client = NewMuxConn(conn)
|
||||
|
||||
// Wait for the server
|
||||
<-doneCh
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMuxConn(t *testing.T) {
|
||||
client, server := testMux(t)
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
// When the server is done
|
||||
doneCh := make(chan struct{})
|
||||
readyCh := make(chan struct{})
|
||||
|
@ -30,20 +61,13 @@ func TestMuxConn(t *testing.T) {
|
|||
// The server side
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
conn, err := l.Accept()
|
||||
l.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mux := NewMuxConn(conn)
|
||||
s0, err := mux.Stream(0)
|
||||
s0, err := server.Stream(0)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
s1, err := mux.Stream(1)
|
||||
s1, err := server.Stream(1)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -72,20 +96,12 @@ func TestMuxConn(t *testing.T) {
|
|||
wg.Wait()
|
||||
}()
|
||||
|
||||
// Client side
|
||||
conn, err := net.Dial("tcp", l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
mux := NewMuxConn(conn)
|
||||
s0, err := mux.Stream(0)
|
||||
s0, err := client.Stream(0)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
s1, err := mux.Stream(1)
|
||||
s1, err := client.Stream(1)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -103,3 +119,47 @@ func TestMuxConn(t *testing.T) {
|
|||
// Wait for the server to be done
|
||||
<-doneCh
|
||||
}
|
||||
|
||||
func TestMuxConn_clientClosesStreams(t *testing.T) {
|
||||
client, server := testMux(t)
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
s0, err := client.Stream(0)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// This should block forever since we never write onto this stream.
|
||||
var data [1024]byte
|
||||
_, err = s0.Read(data[:])
|
||||
if err != io.EOF {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMuxConn_serverClosesStreams(t *testing.T) {
|
||||
client, server := testMux(t)
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
s0, err := client.Stream(0)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := server.Close(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// This should block forever since we never write onto this stream.
|
||||
var data [1024]byte
|
||||
_, err = s0.Read(data[:])
|
||||
if err != io.EOF {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue