packer/rpc: a muxconn...
This commit is contained in:
parent
a66f148ede
commit
fe46093bcf
|
@ -0,0 +1,132 @@
|
||||||
|
package rpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MuxConn is a connection that can be used bi-directionally for RPC. Normally,
|
||||||
|
// Go RPC only allows client-to-server connections. This allows the client
|
||||||
|
// to actually act as a server as well.
|
||||||
|
//
|
||||||
|
// MuxConn works using a fairly dumb multiplexing technique of simply
|
||||||
|
// prefixing each message with whether it is on stream 0 (the original)
|
||||||
|
// or stream 1 (the client "server").
|
||||||
|
//
|
||||||
|
// This can likely be abstracted to N streams, but by choosing only two
|
||||||
|
// we decided to cut a lot of corners and make this easily usable for Packer.
|
||||||
|
type MuxConn struct {
|
||||||
|
rwc io.ReadWriteCloser
|
||||||
|
streams map[byte]io.Writer
|
||||||
|
mu sync.RWMutex
|
||||||
|
wlock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
|
||||||
|
m := &MuxConn{
|
||||||
|
rwc: rwc,
|
||||||
|
streams: make(map[byte]io.Writer),
|
||||||
|
}
|
||||||
|
|
||||||
|
go m.loop()
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream returns a io.ReadWriteCloser that will only read/write to the
|
||||||
|
// given stream ID. No handshake is done so if the remote end does not
|
||||||
|
// have a stream open with the same ID, then the messages will simply
|
||||||
|
// be dropped.
|
||||||
|
//
|
||||||
|
// This is one of those cases where we cut corners. Since Packer only does
|
||||||
|
// local connections, we can assume that both ends are ready at a certain
|
||||||
|
// point. In a real muxer, we'd probably want a handshake here.
|
||||||
|
func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := m.streams[id]; ok {
|
||||||
|
return nil, fmt.Errorf("Stream %d already exists", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the stream object and channel where data will be sent to
|
||||||
|
dataR, dataW := io.Pipe()
|
||||||
|
stream := &Stream{
|
||||||
|
id: id,
|
||||||
|
mux: m,
|
||||||
|
reader: dataR,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the data channel so we can write to it.
|
||||||
|
m.streams[id] = dataW
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MuxConn) loop() {
|
||||||
|
for {
|
||||||
|
var id byte
|
||||||
|
var length int32
|
||||||
|
|
||||||
|
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
|
||||||
|
log.Printf("[ERR] Error reading stream ID: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
|
||||||
|
log.Printf("[ERR] Error reading length: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(mitchellh): probably would be better to re-use a buffer...
|
||||||
|
data := make([]byte, length)
|
||||||
|
if _, err := m.rwc.Read(data); err != nil {
|
||||||
|
log.Printf("[ERR] Error reading data: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
w, ok := m.streams[id]
|
||||||
|
if ok {
|
||||||
|
// Note that if this blocks, it'll block the whole read loop.
|
||||||
|
// Danger here... not sure how to handle it though.
|
||||||
|
w.Write(data)
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MuxConn) write(id byte, p []byte) (int, error) {
|
||||||
|
m.wlock.Lock()
|
||||||
|
defer m.wlock.Unlock()
|
||||||
|
|
||||||
|
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return m.rwc.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream is a single stream of data and implements io.ReadWriteCloser
|
||||||
|
type Stream struct {
|
||||||
|
id byte
|
||||||
|
mux *MuxConn
|
||||||
|
reader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) Close() error {
|
||||||
|
// Not functional yet, does it ever have to be?
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) Read(p []byte) (int, error) {
|
||||||
|
return s.reader.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) Write(p []byte) (int, error) {
|
||||||
|
return s.mux.write(s.id, p)
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
package rpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func readStream(t *testing.T, s io.Reader) string {
|
||||||
|
var data [1024]byte
|
||||||
|
n, err := s.Read(data[:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(data[0:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMuxConn(t *testing.T) {
|
||||||
|
l, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the server is done
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
readyCh := make(chan struct{})
|
||||||
|
|
||||||
|
// The server side
|
||||||
|
go func() {
|
||||||
|
defer close(doneCh)
|
||||||
|
conn, err := l.Accept()
|
||||||
|
l.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
mux := NewMuxConn(conn)
|
||||||
|
s0, err := mux.Stream(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s1, err := mux.Stream(1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
close(readyCh)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
data := readStream(t, s1)
|
||||||
|
if data != "another" {
|
||||||
|
t.Fatalf("bad: %#v", data)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
data := readStream(t, s0)
|
||||||
|
if data != "hello" {
|
||||||
|
t.Fatalf("bad: %#v", data)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Client side
|
||||||
|
conn, err := net.Dial("tcp", l.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
mux := NewMuxConn(conn)
|
||||||
|
s0, err := mux.Stream(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s1, err := mux.Stream(1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the server to be ready
|
||||||
|
<-readyCh
|
||||||
|
|
||||||
|
if _, err := s0.Write([]byte("hello")); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
if _, err := s1.Write([]byte("another")); err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the server to be done
|
||||||
|
<-doneCh
|
||||||
|
}
|
Loading…
Reference in New Issue