packer/plugin: detect invalid versions

This commit is contained in:
Mitchell Hashimoto 2013-09-22 09:51:14 -07:00
parent 6965af291b
commit c7b10cb2cf
5 changed files with 46 additions and 8 deletions

View File

@ -15,6 +15,7 @@ IMPROVEMENTS:
* core: User variables can now be used for integer, boolean, etc. * core: User variables can now be used for integer, boolean, etc.
values. [GH-418] values. [GH-418]
* core: Plugins made with incompatible versions will no longer load.
* builder/amazon/all: Interrupts work while waiting for AMI to be ready. * builder/amazon/all: Interrupts work while waiting for AMI to be ready.
* provisioner/shell: Script line-endings are automatically converted to * provisioner/shell: Script line-endings are automatically converted to
Unix-style line-endings. Can be disabled by setting "binary" to "true". Unix-style line-endings. Can be disabled by setting "binary" to "true".

View File

@ -317,10 +317,24 @@ func (c *Client) Start() (address string, err error) {
err = errors.New("timeout while waiting for plugin to start") err = errors.New("timeout while waiting for plugin to start")
case <-exitCh: case <-exitCh:
err = errors.New("plugin exited before we could connect") err = errors.New("plugin exited before we could connect")
case line := <-linesCh: case lineBytes := <-linesCh:
// Trim the address and reset the err since we were able // Trim the line and split by "|" in order to get the parts of
// to read some sort of address. // the output.
c.address = strings.TrimSpace(string(line)) line := strings.TrimSpace(string(lineBytes))
parts := strings.SplitN(line, "|", 2)
if len(parts) < 2 {
err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
return
}
// Test the API version
if parts[0] != APIVersion {
err = fmt.Errorf("Incompatible API version with plugin. "+
"Plugin version: %s, Ours: %s", parts[0], APIVersion)
return
}
c.address = parts[1]
address = c.address address = c.address
} }

View File

@ -37,6 +37,21 @@ func TestClient(t *testing.T) {
} }
} }
func TestClientStart_badVersion(t *testing.T) {
config := &ClientConfig{
Cmd: helperProcess("bad-version"),
StartTimeout: 50 * time.Millisecond,
}
c := NewClient(config)
defer c.Kill()
_, err := c.Start()
if err == nil {
t.Fatal("err should not be nil")
}
}
func TestClient_Start_Timeout(t *testing.T) { func TestClient_Start_Timeout(t *testing.T) {
config := &ClientConfig{ config := &ClientConfig{
Cmd: helperProcess("start-timeout"), Cmd: helperProcess("start-timeout"),

View File

@ -30,6 +30,11 @@ var Interrupts int32 = 0
const MagicCookieKey = "PACKER_PLUGIN_MAGIC_COOKIE" const MagicCookieKey = "PACKER_PLUGIN_MAGIC_COOKIE"
const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2" const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2"
// The APIVersion is outputted along with the RPC address. The plugin
// client validates this API version and will show an error if it doesn't
// know how to speak it.
const APIVersion = "1"
// This serves a single RPC connection on the given RPC server on // This serves a single RPC connection on the given RPC server on
// a random port. // a random port.
func serve(server *rpc.Server) (err error) { func serve(server *rpc.Server) (err error) {
@ -77,7 +82,7 @@ func serve(server *rpc.Server) (err error) {
// Output the address to stdout // Output the address to stdout
log.Printf("Plugin address: %s\n", address) log.Printf("Plugin address: %s\n", address)
fmt.Println(address) fmt.Printf("%s|%s\n", APIVersion, address)
os.Stdout.Sync() os.Stdout.Sync()
// Accept a connection // Accept a connection

View File

@ -50,6 +50,9 @@ func TestHelperProcess(*testing.T) {
cmd, args := args[0], args[1:] cmd, args := args[0], args[1:]
switch cmd { switch cmd {
case "bad-version":
fmt.Printf("%s1|:1234\n", APIVersion)
<-make(chan int)
case "builder": case "builder":
ServeBuilder(new(helperBuilder)) ServeBuilder(new(helperBuilder))
case "command": case "command":
@ -59,7 +62,7 @@ func TestHelperProcess(*testing.T) {
case "invalid-rpc-address": case "invalid-rpc-address":
fmt.Println("lolinvalid") fmt.Println("lolinvalid")
case "mock": case "mock":
fmt.Println(":1234") fmt.Printf("%s|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "post-processor": case "post-processor":
ServePostProcessor(new(helperPostProcessor)) ServePostProcessor(new(helperPostProcessor))
@ -69,11 +72,11 @@ func TestHelperProcess(*testing.T) {
time.Sleep(1 * time.Minute) time.Sleep(1 * time.Minute)
os.Exit(1) os.Exit(1)
case "stderr": case "stderr":
fmt.Println(":1234") fmt.Printf("%s|:1234\n", APIVersion)
log.Println("HELLO") log.Println("HELLO")
log.Println("WORLD") log.Println("WORLD")
case "stdin": case "stdin":
fmt.Println(":1234") fmt.Printf("%s|:1234\n", APIVersion)
data := make([]byte, 5) data := make([]byte, 5)
if _, err := os.Stdin.Read(data); err != nil { if _, err := os.Stdin.Read(data); err != nil {
log.Printf("stdin read error: %s", err) log.Printf("stdin read error: %s", err)