diff --git a/packer/communicator_mock.go b/packer/communicator_mock.go index e8c6e111e..b6127015f 100644 --- a/packer/communicator_mock.go +++ b/packer/communicator_mock.go @@ -1,34 +1,80 @@ package packer import ( + "bytes" "io" + "sync" ) // MockCommunicator is a valid Communicator implementation that can be // used for tests. type MockCommunicator struct { - Stderr io.Reader - Stdout io.Reader + StartCalled bool + StartCmd *RemoteCmd + StartStderr string + StartStdout string + StartStdin string + StartExitStatus int + + UploadCalled bool + UploadPath string + UploadData string + + DownloadCalled bool + DownloadPath string + DownloadData string } func (c *MockCommunicator) Start(rc *RemoteCmd) error { + c.StartCalled = true + c.StartCmd = rc + go func() { - rc.Lock() - defer rc.Unlock() - - if rc.Stdout != nil && c.Stdout != nil { - io.Copy(rc.Stdout, c.Stdout) + var wg sync.WaitGroup + if rc.Stdout != nil && c.StartStdout != "" { + wg.Add(1) + go func() { + rc.Stdout.Write([]byte(c.StartStdout)) + wg.Done() + }() } - if rc.Stderr != nil && c.Stderr != nil { - io.Copy(rc.Stderr, c.Stderr) + if rc.Stderr != nil && c.StartStderr != "" { + wg.Add(1) + go func() { + rc.Stderr.Write([]byte(c.StartStderr)) + wg.Done() + }() } + + if rc.Stdin != nil { + wg.Add(1) + go func() { + defer wg.Done() + var data bytes.Buffer + io.Copy(&data, rc.Stdin) + c.StartStdin = data.String() + }() + } + + wg.Wait() + rc.SetExited(c.StartExitStatus) }() return nil } -func (c *MockCommunicator) Upload(string, io.Reader) error { +func (c *MockCommunicator) Upload(path string, r io.Reader) error { + c.UploadCalled = true + c.UploadPath = path + + var data bytes.Buffer + if _, err := io.Copy(&data, r); err != nil { + panic(err) + } + + c.UploadData = data.String() + return nil } @@ -36,6 +82,10 @@ func (c *MockCommunicator) UploadDir(string, string, []string) error { return nil } -func (c *MockCommunicator) Download(string, io.Writer) error { +func (c *MockCommunicator) Download(path string, w io.Writer) error { + c.DownloadCalled = true + c.DownloadPath = path + w.Write([]byte(c.DownloadData)) + return nil } diff --git a/packer/communicator_test.go b/packer/communicator_test.go index 432f60e8e..7c88e2059 100644 --- a/packer/communicator_test.go +++ b/packer/communicator_test.go @@ -11,14 +11,10 @@ func TestRemoteCmd_StartWithUi(t *testing.T) { data := "hello\nworld\nthere" originalOutput := new(bytes.Buffer) - rcOutput := new(bytes.Buffer) uiOutput := new(bytes.Buffer) - rcOutput.WriteString(data) - - testComm := &MockCommunicator{ - Stdout: rcOutput, - } + testComm := new(MockCommunicator) + testComm.StartStdout = data testUi := &BasicUi{ Reader: new(bytes.Buffer), Writer: uiOutput, @@ -29,22 +25,20 @@ func TestRemoteCmd_StartWithUi(t *testing.T) { Stdout: originalOutput, } - go func() { - time.Sleep(100 * time.Millisecond) - rc.SetExited(0) - }() - err := rc.StartWithUi(testComm, testUi) if err != nil { t.Fatalf("err: %s", err) } - if uiOutput.String() != strings.TrimSpace(data)+"\n" { + rc.Wait() + + expected := strings.TrimSpace(data) + if uiOutput.String() != expected+"\n" { t.Fatalf("bad output: '%s'", uiOutput.String()) } - if originalOutput.String() != data { - t.Fatalf("original is bad: '%s'", originalOutput.String()) + if originalOutput.String() != expected { + t.Fatalf("bad: %#v", originalOutput.String()) } } diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index dc2e570a3..f5f99c8e9 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -123,6 +123,10 @@ func (c *communicator) Upload(path string, r io.Reader) (err error) { return } +func (c *communicator) UploadDir(dst string, src string, exclude []string) error { + return nil +} + func (c *communicator) Download(path string, w io.Writer) (err error) { // We need to create a server that can proxy that data downloaded // into the writer because we can't gob encode a writer directly. diff --git a/packer/rpc/communicator_test.go b/packer/rpc/communicator_test.go index b54ad0115..cf202ad8b 100644 --- a/packer/rpc/communicator_test.go +++ b/packer/rpc/communicator_test.go @@ -2,52 +2,15 @@ package rpc import ( "bufio" - "cgl.tideland.biz/asserts" "github.com/mitchellh/packer/packer" "io" "net/rpc" "testing" - "time" ) -type testCommunicator struct { - startCalled bool - startCmd *packer.RemoteCmd - - uploadCalled bool - uploadPath string - uploadData string - - downloadCalled bool - downloadPath string -} - -func (t *testCommunicator) Start(cmd *packer.RemoteCmd) error { - t.startCalled = true - t.startCmd = cmd - return nil -} - -func (t *testCommunicator) Upload(path string, reader io.Reader) (err error) { - t.uploadCalled = true - t.uploadPath = path - t.uploadData, err = bufio.NewReader(reader).ReadString('\n') - return -} - -func (t *testCommunicator) Download(path string, writer io.Writer) error { - t.downloadCalled = true - t.downloadPath = path - writer.Write([]byte("download\n")) - - return nil -} - func TestCommunicatorRPC(t *testing.T) { - assert := asserts.NewTestingAsserts(t, true) - // Create the interface to test - c := new(testCommunicator) + c := new(packer.MockCommunicator) // Start the server server := rpc.NewServer() @@ -56,7 +19,9 @@ func TestCommunicatorRPC(t *testing.T) { // Create the client over RPC and run some methods to verify it works client, err := rpc.Dial("tcp", address) - assert.Nil(err, "should be able to connect") + if err != nil { + t.Fatalf("err: %s", err) + } remote := Communicator(client) // The remote command we'll use @@ -70,56 +35,74 @@ func TestCommunicatorRPC(t *testing.T) { cmd.Stdout = stdout_w cmd.Stderr = stderr_w + // Send some data on stdout and stderr from the mock + c.StartStdout = "outfoo\n" + c.StartStderr = "errfoo\n" + c.StartExitStatus = 42 + // Test Start err = remote.Start(&cmd) - assert.Nil(err, "should not have an error") - - // Test that we can read from stdout - c.startCmd.Stdout.Write([]byte("outfoo\n")) - bufOut := bufio.NewReader(stdout_r) - data, err := bufOut.ReadString('\n') - assert.Nil(err, "should have no problem reading stdout") - assert.Equal(data, "outfoo\n", "should be correct stdout") - - // Test that we can read from stderr - c.startCmd.Stderr.Write([]byte("errfoo\n")) - bufErr := bufio.NewReader(stderr_r) - data, err = bufErr.ReadString('\n') - assert.Nil(err, "should have no problem reading stderr") - assert.Equal(data, "errfoo\n", "should be correct stderr") - - // Test that we can write to stdin - stdin_w.Write([]byte("infoo\n")) - bufIn := bufio.NewReader(c.startCmd.Stdin) - data, err = bufIn.ReadString('\n') - assert.Nil(err, "should have no problem reading stdin") - assert.Equal(data, "infoo\n", "should be correct stdin") - - // Test that we can get the exit status properly - c.startCmd.SetExited(42) - - for i := 0; i < 5; i++ { - cmd.Lock() - exited := cmd.Exited - cmd.Unlock() - if exited { - assert.Equal(cmd.ExitStatus, 42, "should have proper exit status") - break - } - - time.Sleep(50 * time.Millisecond) + if err != nil { + t.Fatalf("err: %s", err) } - assert.True(cmd.Exited, "should have exited") + // Test that we can read from stdout + bufOut := bufio.NewReader(stdout_r) + data, err := bufOut.ReadString('\n') + if err != nil { + t.Fatalf("err: %s", err) + } + + if data != "outfoo\n" { + t.Fatalf("bad data: %s", data) + } + + // Test that we can read from stderr + bufErr := bufio.NewReader(stderr_r) + data, err = bufErr.ReadString('\n') + if err != nil { + t.Fatalf("err: %s", err) + } + + if data != "errfoo\n" { + t.Fatalf("bad data: %s", data) + } + + // Test that we can write to stdin + stdin_w.Write([]byte("info\n")) + stdin_w.Close() + cmd.Wait() + if c.StartStdin != "info\n" { + t.Fatalf("bad data: %s", data) + } + + // Test that we can get the exit status properly + if cmd.ExitStatus != 42 { + t.Fatalf("bad exit: %d", cmd.ExitStatus) + } // Test that we can upload things uploadR, uploadW := io.Pipe() - go uploadW.Write([]byte("uploadfoo\n")) + go func() { + defer uploadW.Close() + uploadW.Write([]byte("uploadfoo\n")) + }() err = remote.Upload("foo", uploadR) - assert.Nil(err, "should not error") - assert.True(c.uploadCalled, "should be called") - assert.Equal(c.uploadPath, "foo", "should be correct path") - assert.Equal(c.uploadData, "uploadfoo\n", "should have the proper data") + if err != nil { + t.Fatalf("err: %s", err) + } + + if !c.UploadCalled { + t.Fatal("should have uploaded") + } + + if c.UploadPath != "foo" { + t.Fatalf("path: %s", c.UploadPath) + } + + if c.UploadData != "uploadfoo\n" { + t.Fatalf("bad: %s", c.UploadData) + } // Test that we can download things downloadR, downloadW := io.Pipe() @@ -133,21 +116,34 @@ func TestCommunicatorRPC(t *testing.T) { downloadDone <- true }() + c.DownloadData = "download\n" err = remote.Download("bar", downloadW) - assert.Nil(err, "should not error") - assert.True(c.downloadCalled, "should have called download") - assert.Equal(c.downloadPath, "bar", "should have correct download path") + if err != nil { + t.Fatalf("err: %s", err) + } + + if !c.DownloadCalled { + t.Fatal("download should be called") + } + + if c.DownloadPath != "bar" { + t.Fatalf("bad: %s", c.DownloadPath) + } <-downloadDone - assert.Nil(downloadErr, "should not error reading download data") - assert.Equal(downloadData, "download\n", "should have the proper data") + if downloadErr != nil { + t.Fatalf("err: %s", downloadErr) + } + + if downloadData != "download\n" { + t.Fatalf("bad: %s", downloadData) + } } func TestCommunicator_ImplementsCommunicator(t *testing.T) { - assert := asserts.NewTestingAsserts(t, true) - - var r packer.Communicator - c := Communicator(nil) - - assert.Implementor(c, &r, "should be a Communicator") + var raw interface{} + raw = Communicator(nil) + if _, ok := raw.(packer.Communicator); !ok { + t.Fatal("should be a Communicator") + } } diff --git a/packer/rpc/provisioner_test.go b/packer/rpc/provisioner_test.go index 56b3b80c4..106ae62eb 100644 --- a/packer/rpc/provisioner_test.go +++ b/packer/rpc/provisioner_test.go @@ -52,7 +52,7 @@ func TestProvisionerRPC(t *testing.T) { // Test Provision ui := &testUi{} - comm := &testCommunicator{} + comm := new(packer.MockCommunicator) pClient.Provision(ui, comm) assert.True(p.provCalled, "provision should be called")