diff --git a/packer/communicator.go b/packer/communicator.go index 7db4f3391..a959d79c6 100644 --- a/packer/communicator.go +++ b/packer/communicator.go @@ -1,7 +1,9 @@ package packer import ( + "github.com/mitchellh/iochan" "io" + "strings" "time" ) @@ -57,6 +59,61 @@ type Communicator interface { Download(string, io.Writer) error } +// StartWithUi runs the remote command and streams the output to any +// configured Writers for stdout/stderr, while also writing each line +// as it comes to a Ui. +func (r *RemoteCmd) StartWithUi(c Communicator, ui Ui) error { + stdout_r, stdout_w := io.Pipe() + stderr_r, stderr_w := io.Pipe() + + // Set the writers for the output so that we get it streamed to us + r.Stdout = stdout_w + r.Stderr = stderr_w + + // Start the command + if err := c.Start(r); err != nil { + return err + } + + // Create the channels we'll use for data + exitCh := make(chan int, 1) + stdoutCh := iochan.DelimReader(stdout_r, '\n') + stderrCh := iochan.DelimReader(stderr_r, '\n') + + // Start the goroutine to watch for the exit + go func() { + defer stdout_w.Close() + defer stderr_w.Close() + r.Wait() + exitCh <- r.ExitStatus + }() + + // Loop and get all our output +OutputLoop: + for { + select { + case output := <-stderrCh: + ui.Message(strings.TrimSpace(output)) + case output := <-stdoutCh: + ui.Message(strings.TrimSpace(output)) + case <-exitCh: + break OutputLoop + } + } + + // Make sure we finish off stdout/stderr because we may have gotten + // a message from the exit channel before finishing these first. + for output := range stdoutCh { + ui.Message(strings.TrimSpace(output)) + } + + for output := range stderrCh { + ui.Message(strings.TrimSpace(output)) + } + + return nil +} + // Wait waits for the remote command to complete. func (r *RemoteCmd) Wait() { for !r.Exited { diff --git a/packer/communicator_test.go b/packer/communicator_test.go index 2c4e9e33b..484b508e9 100644 --- a/packer/communicator_test.go +++ b/packer/communicator_test.go @@ -1,10 +1,75 @@ package packer import ( + "bytes" + "io" + "strings" "testing" "time" ) +type TestCommunicator struct { + Stderr io.Reader + Stdout io.Reader +} + +func (c *TestCommunicator) Start(rc *RemoteCmd) error { + go func() { + if rc.Stdout != nil && c.Stdout != nil { + io.Copy(rc.Stdout, c.Stdout) + } + + if rc.Stderr != nil && c.Stderr != nil { + io.Copy(rc.Stderr, c.Stderr) + } + }() + + return nil +} + +func (c *TestCommunicator) Upload(string, io.Reader) error { + return nil +} + +func (c *TestCommunicator) Download(string, io.Writer) error { + return nil +} + +func TestRemoteCmd_StartWithUi(t *testing.T) { + data := "hello\nworld\nthere" + + rcOutput := new(bytes.Buffer) + uiOutput := new(bytes.Buffer) + rcOutput.WriteString(data) + + testComm := &TestCommunicator{ + Stdout: rcOutput, + } + + testUi := &ReaderWriterUi{ + Reader: new(bytes.Buffer), + Writer: uiOutput, + } + + rc := &RemoteCmd{ + Command: "test", + } + + go func() { + time.Sleep(100 * time.Millisecond) + rc.Exited = true + }() + + err := rc.StartWithUi(testComm, testUi) + if err != nil { + t.Fatalf("err: %s", err) + } + + if uiOutput.String() != strings.TrimSpace(data)+"\n" { + t.Fatalf("bad output: '%s'", uiOutput.String()) + } +} + func TestRemoteCmd_Wait(t *testing.T) { var cmd RemoteCmd