From a6397c160aedbd5e3c19820645e3d592b59eb5b2 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sun, 22 Sep 2013 09:51:14 -0700 Subject: [PATCH] packer/plugin: detect invalid versions --- CHANGELOG.md | 1 + packer/plugin/client.go | 22 ++++++++++++++++++---- packer/plugin/client_test.go | 15 +++++++++++++++ packer/plugin/plugin.go | 7 ++++++- packer/plugin/plugin_test.go | 9 ++++++--- 5 files changed, 46 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2e696741..2bd9fc504 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ IMPROVEMENTS: * core: User variables can now be used for integer, boolean, etc. values. [GH-418] +* core: Plugins made with incompatible versions will no longer load. * builder/amazon/all: Interrupts work while waiting for AMI to be ready. * provisioner/shell: Script line-endings are automatically converted to Unix-style line-endings. Can be disabled by setting "binary" to "true". diff --git a/packer/plugin/client.go b/packer/plugin/client.go index d3a302977..792f9d99d 100644 --- a/packer/plugin/client.go +++ b/packer/plugin/client.go @@ -317,10 +317,24 @@ func (c *Client) Start() (address string, err error) { err = errors.New("timeout while waiting for plugin to start") case <-exitCh: err = errors.New("plugin exited before we could connect") - case line := <-linesCh: - // Trim the address and reset the err since we were able - // to read some sort of address. - c.address = strings.TrimSpace(string(line)) + case lineBytes := <-linesCh: + // Trim the line and split by "|" in order to get the parts of + // the output. + line := strings.TrimSpace(string(lineBytes)) + parts := strings.SplitN(line, "|", 2) + if len(parts) < 2 { + err = fmt.Errorf("Unrecognized remote plugin message: %s", line) + return + } + + // Test the API version + if parts[0] != APIVersion { + err = fmt.Errorf("Incompatible API version with plugin. "+ + "Plugin version: %s, Ours: %s", parts[0], APIVersion) + return + } + + c.address = parts[1] address = c.address } diff --git a/packer/plugin/client_test.go b/packer/plugin/client_test.go index ae71c3362..f9257034e 100644 --- a/packer/plugin/client_test.go +++ b/packer/plugin/client_test.go @@ -37,6 +37,21 @@ func TestClient(t *testing.T) { } } +func TestClientStart_badVersion(t *testing.T) { + config := &ClientConfig{ + Cmd: helperProcess("bad-version"), + StartTimeout: 50 * time.Millisecond, + } + + c := NewClient(config) + defer c.Kill() + + _, err := c.Start() + if err == nil { + t.Fatal("err should not be nil") + } +} + func TestClient_Start_Timeout(t *testing.T) { config := &ClientConfig{ Cmd: helperProcess("start-timeout"), diff --git a/packer/plugin/plugin.go b/packer/plugin/plugin.go index 9afba0285..a91fcc3ce 100644 --- a/packer/plugin/plugin.go +++ b/packer/plugin/plugin.go @@ -30,6 +30,11 @@ var Interrupts int32 = 0 const MagicCookieKey = "PACKER_PLUGIN_MAGIC_COOKIE" const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2" +// The APIVersion is outputted along with the RPC address. The plugin +// client validates this API version and will show an error if it doesn't +// know how to speak it. +const APIVersion = "1" + // This serves a single RPC connection on the given RPC server on // a random port. func serve(server *rpc.Server) (err error) { @@ -77,7 +82,7 @@ func serve(server *rpc.Server) (err error) { // Output the address to stdout log.Printf("Plugin address: %s\n", address) - fmt.Println(address) + fmt.Printf("%s|%s\n", APIVersion, address) os.Stdout.Sync() // Accept a connection diff --git a/packer/plugin/plugin_test.go b/packer/plugin/plugin_test.go index 17018f82d..10c3f9d5c 100644 --- a/packer/plugin/plugin_test.go +++ b/packer/plugin/plugin_test.go @@ -50,6 +50,9 @@ func TestHelperProcess(*testing.T) { cmd, args := args[0], args[1:] switch cmd { + case "bad-version": + fmt.Printf("%s1|:1234\n", APIVersion) + <-make(chan int) case "builder": ServeBuilder(new(helperBuilder)) case "command": @@ -59,7 +62,7 @@ func TestHelperProcess(*testing.T) { case "invalid-rpc-address": fmt.Println("lolinvalid") case "mock": - fmt.Println(":1234") + fmt.Printf("%s|:1234\n", APIVersion) <-make(chan int) case "post-processor": ServePostProcessor(new(helperPostProcessor)) @@ -69,11 +72,11 @@ func TestHelperProcess(*testing.T) { time.Sleep(1 * time.Minute) os.Exit(1) case "stderr": - fmt.Println(":1234") + fmt.Printf("%s|:1234\n", APIVersion) log.Println("HELLO") log.Println("WORLD") case "stdin": - fmt.Println(":1234") + fmt.Printf("%s|:1234\n", APIVersion) data := make([]byte, 5) if _, err := os.Stdin.Read(data); err != nil { log.Printf("stdin read error: %s", err)