From fe46093bcf354295ca77150fb612d87d1219f6cf Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sun, 8 Dec 2013 18:20:27 -0800 Subject: [PATCH] packer/rpc: a muxconn... --- packer/rpc/muxconn.go | 132 +++++++++++++++++++++++++++++++++++++ packer/rpc/muxconn_test.go | 105 +++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) create mode 100644 packer/rpc/muxconn.go create mode 100644 packer/rpc/muxconn_test.go diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go new file mode 100644 index 000000000..b1f1df2e5 --- /dev/null +++ b/packer/rpc/muxconn.go @@ -0,0 +1,132 @@ +package rpc + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" +) + +// MuxConn is a connection that can be used bi-directionally for RPC. Normally, +// Go RPC only allows client-to-server connections. This allows the client +// to actually act as a server as well. +// +// MuxConn works using a fairly dumb multiplexing technique of simply +// prefixing each message with whether it is on stream 0 (the original) +// or stream 1 (the client "server"). +// +// This can likely be abstracted to N streams, but by choosing only two +// we decided to cut a lot of corners and make this easily usable for Packer. +type MuxConn struct { + rwc io.ReadWriteCloser + streams map[byte]io.Writer + mu sync.RWMutex + wlock sync.Mutex +} + +func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { + m := &MuxConn{ + rwc: rwc, + streams: make(map[byte]io.Writer), + } + + go m.loop() + + return m +} + +// Stream returns a io.ReadWriteCloser that will only read/write to the +// given stream ID. No handshake is done so if the remote end does not +// have a stream open with the same ID, then the messages will simply +// be dropped. +// +// This is one of those cases where we cut corners. Since Packer only does +// local connections, we can assume that both ends are ready at a certain +// point. In a real muxer, we'd probably want a handshake here. +func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.streams[id]; ok { + return nil, fmt.Errorf("Stream %d already exists", id) + } + + // Create the stream object and channel where data will be sent to + dataR, dataW := io.Pipe() + stream := &Stream{ + id: id, + mux: m, + reader: dataR, + } + + // Set the data channel so we can write to it. + m.streams[id] = dataW + + return stream, nil +} + +func (m *MuxConn) loop() { + for { + var id byte + var length int32 + + if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { + log.Printf("[ERR] Error reading stream ID: %s", err) + return + } + if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil { + log.Printf("[ERR] Error reading length: %s", err) + return + } + + // TODO(mitchellh): probably would be better to re-use a buffer... + data := make([]byte, length) + if _, err := m.rwc.Read(data); err != nil { + log.Printf("[ERR] Error reading data: %s", err) + return + } + + m.mu.RLock() + w, ok := m.streams[id] + if ok { + // Note that if this blocks, it'll block the whole read loop. + // Danger here... not sure how to handle it though. + w.Write(data) + } + m.mu.RUnlock() + } +} + +func (m *MuxConn) write(id byte, p []byte) (int, error) { + m.wlock.Lock() + defer m.wlock.Unlock() + + if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { + return 0, err + } + if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { + return 0, err + } + return m.rwc.Write(p) +} + +// Stream is a single stream of data and implements io.ReadWriteCloser +type Stream struct { + id byte + mux *MuxConn + reader io.Reader +} + +func (s *Stream) Close() error { + // Not functional yet, does it ever have to be? + return nil +} + +func (s *Stream) Read(p []byte) (int, error) { + return s.reader.Read(p) +} + +func (s *Stream) Write(p []byte) (int, error) { + return s.mux.write(s.id, p) +} diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go new file mode 100644 index 000000000..4493e1c5f --- /dev/null +++ b/packer/rpc/muxconn_test.go @@ -0,0 +1,105 @@ +package rpc + +import ( + "io" + "net" + "sync" + "testing" +) + +func readStream(t *testing.T, s io.Reader) string { + var data [1024]byte + n, err := s.Read(data[:]) + if err != nil { + t.Fatalf("err: %s", err) + } + + return string(data[0:n]) +} + +func TestMuxConn(t *testing.T) { + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("err: %s", err) + } + + // When the server is done + doneCh := make(chan struct{}) + readyCh := make(chan struct{}) + + // The server side + go func() { + defer close(doneCh) + conn, err := l.Accept() + l.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + defer conn.Close() + + mux := NewMuxConn(conn) + s0, err := mux.Stream(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + s1, err := mux.Stream(1) + if err != nil { + t.Fatalf("err: %s", err) + } + + close(readyCh) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + data := readStream(t, s1) + if data != "another" { + t.Fatalf("bad: %#v", data) + } + }() + + go func() { + defer wg.Done() + data := readStream(t, s0) + if data != "hello" { + t.Fatalf("bad: %#v", data) + } + }() + + wg.Wait() + }() + + // Client side + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + defer conn.Close() + + mux := NewMuxConn(conn) + s0, err := mux.Stream(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + s1, err := mux.Stream(1) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server to be ready + <-readyCh + + if _, err := s0.Write([]byte("hello")); err != nil { + t.Fatalf("err: %s", err) + } + if _, err := s1.Write([]byte("another")); err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server to be done + <-doneCh +}