diff --git a/packer/rpc/mux_broker.go b/packer/rpc/mux_broker.go new file mode 100644 index 000000000..2e1061c9d --- /dev/null +++ b/packer/rpc/mux_broker.go @@ -0,0 +1,160 @@ +package rpc + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "github.com/hashicorp/yamux" +) + +// muxBroker is responsible for brokering multiplexed connections by unique ID. +// +// This allows a plugin to request a channel with a specific ID to connect to +// or accept a connection from, and the broker handles the details of +// holding these channels open while they're being negotiated. +type muxBroker struct { + session *yamux.Session + streams map[uint32]*muxBrokerPending + + sync.Mutex +} + +type muxBrokerPending struct { + ch chan net.Conn + doneCh chan struct{} +} + +func newMuxBroker(s *yamux.Session) *muxBroker { + return &muxBroker{ + session: s, + streams: make(map[uint32]*muxBrokerPending), + } +} + +// Accept accepts a connection by ID. +// +// This should not be called multiple times with the same ID at one time. +func (m *muxBroker) Accept(id uint32) (net.Conn, error) { + var c net.Conn + p := m.getStream(id) + select { + case c = <-p.ch: + close(p.doneCh) + case <-time.After(5 * time.Second): + m.Lock() + defer m.Unlock() + delete(m.streams, id) + + return nil, fmt.Errorf("timeout waiting for accept") + } + + // Ack our connection + if err := binary.Write(c, binary.LittleEndian, id); err != nil { + c.Close() + return nil, err + } + + return c, nil +} + +// Dial opens a connection by ID. +func (m *muxBroker) Dial(id uint32) (net.Conn, error) { + // Open the stream + stream, err := m.session.OpenStream() + if err != nil { + return nil, err + } + + // Write the stream ID onto the wire. + if err := binary.Write(stream, binary.LittleEndian, id); err != nil { + stream.Close() + return nil, err + } + + // Read the ack that we connected. Then we're off! + var ack uint32 + if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { + stream.Close() + return nil, err + } + if ack != id { + stream.Close() + return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) + } + + return stream, nil +} + +// Run starts the brokering and should be executed in a goroutine, since it +// blocks forever, or until the session closes. +func (m *muxBroker) Run() { + for { + stream, err := m.session.AcceptStream() + if err != nil { + // Once we receive an error, just exit + break + } + + // Read the stream ID from the stream + var id uint32 + if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { + stream.Close() + continue + } + + // Initialize the waiter + p := m.getStream(id) + select { + case p.ch <- stream: + default: + } + + // Wait for a timeout + go m.timeoutWait(id, p) + } +} + +func (m *muxBroker) getStream(id uint32) *muxBrokerPending { + m.Lock() + defer m.Unlock() + + p, ok := m.streams[id] + if ok { + return p + } + + m.streams[id] = &muxBrokerPending{ + ch: make(chan net.Conn, 1), + doneCh: make(chan struct{}), + } + return m.streams[id] +} + +func (m *muxBroker) timeoutWait(id uint32, p *muxBrokerPending) { + // Wait for the stream to either be picked up and connected, or + // for a timeout. + timeout := false + select { + case <-p.doneCh: + case <-time.After(5 * time.Second): + timeout = true + } + + m.Lock() + defer m.Unlock() + + // Delete the stream so no one else can grab it + delete(m.streams, id) + + // If we timed out, then check if we have a channel in the buffer, + // and if so, close it. + if timeout { + select { + case s := <-p.ch: + s.Close() + } + } +} diff --git a/packer/rpc/mux_broker_test.go b/packer/rpc/mux_broker_test.go new file mode 100644 index 000000000..88739a0ff --- /dev/null +++ b/packer/rpc/mux_broker_test.go @@ -0,0 +1,82 @@ +package rpc + +import ( + "net" + "testing" + + "github.com/hashicorp/yamux" +) + +func TestMuxBroker(t *testing.T) { + c, s := testYamux(t) + defer c.Close() + defer s.Close() + + bc := newMuxBroker(c) + bs := newMuxBroker(s) + go bc.Run() + go bs.Run() + + go func() { + c, err := bc.Dial(5) + if err != nil { + t.Fatalf("err: %s", err) + } + + if _, err := c.Write([]byte{42}); err != nil { + t.Fatalf("err: %s", err) + } + }() + + client, err := bs.Accept(5) + if err != nil { + t.Fatalf("err: %s", err) + } + + var data [1]byte + if _, err := client.Read(data[:]); err != nil { + t.Fatalf("err: %s", err) + } + + if data[0] != 42 { + t.Fatalf("bad: %d", data[0]) + } +} + +func testYamux(t *testing.T) (client *yamux.Session, server *yamux.Session) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Server side + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + conn, err := l.Accept() + l.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + server, err = yamux.Server(conn, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + }() + + // Client side + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + client, err = yamux.Client(conn, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server + <-doneCh + + return +}