diff --git a/provisioner/file/provisioner.go b/provisioner/file/provisioner.go index 711beffa6..2e7c17190 100644 --- a/provisioner/file/provisioner.go +++ b/provisioner/file/provisioner.go @@ -30,16 +30,18 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { errs := []error{} if _, err := os.Stat(p.config.Source); err != nil { - errs = append(errs, fmt.Errorf("Bad source file '%s': %s", p.config.Source, err)) + errs = append(errs, + fmt.Errorf("Bad source '%s': %s", p.config.Source, err)) } - if len(p.config.Destination) == 0 { + if p.config.Destination == "" { errs = append(errs, errors.New("Destination must be specified.")) } if len(errs) > 0 { return &packer.MultiError{errs} } + return nil } @@ -49,5 +51,7 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { if err != nil { return err } + defer f.Close() + return comm.Upload(p.config.Destination, f) } diff --git a/provisioner/file/provisioner_test.go b/provisioner/file/provisioner_test.go index 017bd3365..6fa1e031a 100644 --- a/provisioner/file/provisioner_test.go +++ b/provisioner/file/provisioner_test.go @@ -9,6 +9,12 @@ import ( "testing" ) +func testConfig() map[string]interface{} { + return map[string]interface{}{ + "destination": "something", + } +} + func TestProvisioner_Impl(t *testing.T) { var raw interface{} raw = &Provisioner{} @@ -19,10 +25,10 @@ func TestProvisioner_Impl(t *testing.T) { func TestProvisionerPrepare_InvalidSource(t *testing.T) { var p Provisioner - config := map[string]interface{}{"source": "/this/should/not/exist", "destination": "something"} + config := testConfig() + config["source"] = "/this/should/not/exist" err := p.Prepare(config) - if err == nil { t.Fatalf("should require existing file") } @@ -36,10 +42,11 @@ func TestProvisionerPrepare_ValidSource(t *testing.T) { t.Fatalf("error tempfile: %s", err) } defer os.Remove(tf.Name()) - config := map[string]interface{}{"source": tf.Name(), "destination": "something"} + + config := testConfig() + config["source"] = tf.Name() err = p.Prepare(config) - if err != nil { t.Fatalf("should allow valid file: %s", err) } @@ -47,10 +54,10 @@ func TestProvisionerPrepare_ValidSource(t *testing.T) { func TestProvisionerPrepare_EmptyDestination(t *testing.T) { var p Provisioner - config := map[string]interface{}{"source": "/this/exists"} + config := testConfig() + delete(config, "destination") err := p.Prepare(config) - if err == nil { t.Fatalf("should require destination path") } @@ -58,7 +65,7 @@ func TestProvisionerPrepare_EmptyDestination(t *testing.T) { type stubUploadCommunicator struct { dest string - data io.Reader + data []byte } func (suc *stubUploadCommunicator) Download(src string, data io.Writer) error { @@ -66,9 +73,10 @@ func (suc *stubUploadCommunicator) Download(src string, data io.Writer) error { } func (suc *stubUploadCommunicator) Upload(dest string, data io.Reader) error { + var err error suc.dest = dest - suc.data = data - return nil + suc.data, err = ioutil.ReadAll(data) + return err } func (suc *stubUploadCommunicator) Start(cmd *packer.RemoteCmd) error { @@ -100,11 +108,19 @@ func TestProvisionerProvision_SendsFile(t *testing.T) { t.Fatalf("error tempfile: %s", err) } defer os.Remove(tf.Name()) + if _, err = tf.Write([]byte("hello")); err != nil { t.Fatalf("error writing tempfile: %s", err) } - config := map[string]interface{}{"source": tf.Name(), "destination": "something"} - p.Prepare(config) + + config := map[string]interface{}{ + "source": tf.Name(), + "destination": "something", + } + + if err := p.Prepare(config); err != nil { + t.Fatalf("err: %s", err) + } ui := &stubUi{} comm := &stubUploadCommunicator{} @@ -112,17 +128,20 @@ func TestProvisionerProvision_SendsFile(t *testing.T) { if err != nil { t.Fatalf("should successfully provision: %s", err) } + if !strings.Contains(ui.sayMessages, tf.Name()) { t.Fatalf("should print source filename") } + if !strings.Contains(ui.sayMessages, "something") { t.Fatalf("should print destination filename") } + if comm.dest != "something" { t.Fatalf("should upload to configured destination") } - read, err := ioutil.ReadAll(comm.data) - if err != nil || string(read) != "hello" { + + if string(comm.data) != "hello" { t.Fatalf("should upload with source file's data") } }