diff --git a/provisioner/shell/provisioner.go b/provisioner/shell/provisioner.go index f66128bf8..cd447f1d8 100644 --- a/provisioner/shell/provisioner.go +++ b/provisioner/shell/provisioner.go @@ -5,6 +5,7 @@ package shell import ( "bytes" "fmt" + "errors" "github.com/mitchellh/iochan" "github.com/mitchellh/mapstructure" "github.com/mitchellh/packer/packer" @@ -18,6 +19,10 @@ import ( const DefaultRemotePath = "/tmp/script.sh" type config struct { + // An inline script to execute. Multiple strings are all executed + // in the context of a single shell. + Inline []string + // The local path of the shell script to upload and execute. Path string @@ -49,10 +54,32 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { p.config.ExecuteCommand = "chmod +x {{.Path}} && {{.Path}}" } + if p.config.Inline != nil && len(p.config.Inline) == 0 { + p.config.Inline = nil + } + if p.config.RemotePath == "" { p.config.RemotePath = DefaultRemotePath } + errs := make([]error, 0) + + if p.config.Path == "" && p.config.Inline == nil { + errs = append(errs, errors.New("Either a path or inline script must be specified.")) + } else if p.config.Path != "" && p.config.Inline != nil { + errs = append(errs, errors.New("Only a path or an inline script can be specified, not both.")) + } + + if p.config.Path != "" { + if _, err := os.Stat(p.config.Path); err != nil { + errs = append(errs, fmt.Errorf("Bad script path: %s", err)) + } + } + + if len(errs) > 0 { + return &packer.MultiError{errs} + } + return nil } diff --git a/provisioner/shell/provisioner_test.go b/provisioner/shell/provisioner_test.go index 5da089187..a7ade6d5a 100644 --- a/provisioner/shell/provisioner_test.go +++ b/provisioner/shell/provisioner_test.go @@ -2,9 +2,17 @@ package shell import ( "github.com/mitchellh/packer/packer" + "io/ioutil" + "os" "testing" ) +func testConfig() map[string]interface{} { + return map[string]interface{}{ + "inline": []interface{}{"foo", "bar"}, + } +} + func TestProvisioner_Impl(t *testing.T) { var raw interface{} raw = &Provisioner{} @@ -14,10 +22,10 @@ func TestProvisioner_Impl(t *testing.T) { } func TestProvisionerPrepare_Defaults(t *testing.T) { - raw := map[string]interface{}{} + var p Provisioner + config := testConfig() - p := &Provisioner{} - err := p.Prepare(raw) + err := p.Prepare(config) if err != nil { t.Fatalf("err: %s", err) } @@ -26,3 +34,54 @@ func TestProvisionerPrepare_Defaults(t *testing.T) { t.Errorf("unexpected remote path: %s", p.config.RemotePath) } } + +func TestProvisionerPrepare_Path(t *testing.T) { + var p Provisioner + config := testConfig() + delete(config, "inline") + + config["path"] = "/this/should/not/exist" + err := p.Prepare(config) + if err == nil { + t.Fatal("should have error") + } + + // Test with a good one + tf, err := ioutil.TempFile("", "packer") + if err != nil { + t.Fatalf("error tempfile: %s", err) + } + defer os.Remove(tf.Name()) + + config["path"] = tf.Name() + err = p.Prepare(config) + if err != nil { + t.Fatalf("should not have error: %s", err) + } +} + +func TestProvisionerPrepare_PathAndInline(t *testing.T) { + var p Provisioner + config := testConfig() + + delete(config, "inline") + delete(config, "path") + err := p.Prepare(config) + if err == nil { + t.Fatal("should have error") + } + + // Test with both + tf, err := ioutil.TempFile("", "packer") + if err != nil { + t.Fatalf("error tempfile: %s", err) + } + defer os.Remove(tf.Name()) + + config["inline"] = []interface{}{"foo"} + config["path"] = tf.Name() + err = p.Prepare(config) + if err == nil { + t.Fatal("should have error") + } +}