diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 3332108d6..825dc3ab2 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -95,38 +95,6 @@ func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { } func (c *comm) Upload(path string, input io.Reader) error { - session, err := c.newSession() - if err != nil { - return err - } - - defer session.Close() - - // Get a pipe to stdin so that we can send data down - w, err := session.StdinPipe() - if err != nil { - return err - } - - // We only want to close once, so we nil w after we close it, - // and only close in the defer if it hasn't been closed already. - defer func() { - if w != nil { - w.Close() - } - }() - - // Get a pipe to stdout so that we can get responses back - stdoutPipe, err := session.StdoutPipe() - if err != nil { - return err - } - stdoutR := bufio.NewReader(stdoutPipe) - - // Set stderr to a bytes buffer - stderr := new(bytes.Buffer) - session.Stderr = stderr - // The target directory and file for talking the SCP protocol target_dir := filepath.Dir(path) target_file := filepath.Base(path) @@ -136,72 +104,38 @@ func (c *comm) Upload(path string, input io.Reader) error { // which works for unix and windows target_dir = filepath.ToSlash(target_dir) - // Start the sink mode on the other side - // TODO(mitchellh): There are probably issues with shell escaping the path - log.Println("Starting remote scp process in sink mode") - if err = session.Start("scp -vt " + target_dir); err != nil { - return err - } - - // Determine the length of the upload content by copying it - // into an in-memory buffer. Note that this means what we upload - // must fit into memory. - log.Println("Copying input data into in-memory buffer so we can get the length") - input_memory := new(bytes.Buffer) - if _, err = io.Copy(input_memory, input); err != nil { - return err - } - - // Start the protocol - log.Println("Beginning file upload...") - fmt.Fprintln(w, "C0644", input_memory.Len(), target_file) - err = checkSCPStatus(stdoutR) - if err != nil { - return err - } - - if _, err := io.Copy(w, input_memory); err != nil { - return err - } - - fmt.Fprint(w, "\x00") - err = checkSCPStatus(stdoutR) - if err != nil { - return err - } - - // Close the stdin, which sends an EOF, and then set w to nil so that - // our defer func doesn't close it again since that is unsafe with - // the Go SSH package. - log.Println("Upload complete, closing stdin pipe") - w.Close() - w = nil - - // Wait for the SCP connection to close, meaning it has consumed all - // our data and has completed. Or has errored. - log.Println("Waiting for SSH session to complete") - err = session.Wait() - if err != nil { - if exitErr, ok := err.(*ssh.ExitError); ok { - // Otherwise, we have an ExitErorr, meaning we can just read - // the exit status - log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) - - // If we exited with status 127, it means SCP isn't available. - // Return a more descriptive error for that. - if exitErr.ExitStatus() == 127 { - return errors.New( - "SCP failed to start. This usually means that SCP is not\n" + - "properly installed on the remote system.") - } + scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { + // Determine the length of the upload content by copying it + // into an in-memory buffer. Note that this means what we upload + // must fit into memory. + log.Println("Copying input data into in-memory buffer so we can get the length") + input_memory := new(bytes.Buffer) + if _, err := io.Copy(input_memory, input); err != nil { + return err } - return err + // Start the protocol + log.Println("Beginning file upload...") + fmt.Fprintln(w, "C0644", input_memory.Len(), target_file) + err := checkSCPStatus(stdoutR) + if err != nil { + return err + } + + if _, err := io.Copy(w, input_memory); err != nil { + return err + } + + fmt.Fprint(w, "\x00") + err = checkSCPStatus(stdoutR) + if err != nil { + return err + } + + return nil } - log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) - - return nil + return c.scpSession("scp -vt "+target_dir, scpFunc) } func (c *comm) UploadDir(dst string, src string, excl []string) error { @@ -257,6 +191,84 @@ func (c *comm) reconnect() (err error) { return } +func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { + session, err := c.newSession() + if err != nil { + return err + } + defer session.Close() + + // Get a pipe to stdin so that we can send data down + stdinW, err := session.StdinPipe() + if err != nil { + return err + } + + // We only want to close once, so we nil w after we close it, + // and only close in the defer if it hasn't been closed already. + defer func() { + if stdinW != nil { + stdinW.Close() + } + }() + + // Get a pipe to stdout so that we can get responses back + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + stdoutR := bufio.NewReader(stdoutPipe) + + // Set stderr to a bytes buffer + stderr := new(bytes.Buffer) + session.Stderr = stderr + + // Start the sink mode on the other side + // TODO(mitchellh): There are probably issues with shell escaping the path + log.Println("Starting remote scp process: %s", scpCommand) + if err := session.Start(scpCommand); err != nil { + return err + } + + // Call our callback that executes in the context of SCP + log.Println("Started SCP session, beginning transfers...") + if err := f(stdinW, stdoutR); err != nil { + return err + } + + // Close the stdin, which sends an EOF, and then set w to nil so that + // our defer func doesn't close it again since that is unsafe with + // the Go SSH package. + log.Println("SCP session complete, closing stdin pipe.") + stdinW.Close() + stdinW = nil + + // Wait for the SCP connection to close, meaning it has consumed all + // our data and has completed. Or has errored. + log.Println("Waiting for SSH session to complete.") + err = session.Wait() + if err != nil { + if exitErr, ok := err.(*ssh.ExitError); ok { + // Otherwise, we have an ExitErorr, meaning we can just read + // the exit status + log.Printf("non-zero exit status: %d", exitErr.ExitStatus()) + + // If we exited with status 127, it means SCP isn't available. + // Return a more descriptive error for that. + if exitErr.ExitStatus() == 127 { + return errors.New( + "SCP failed to start. This usually means that SCP is not\n" + + "properly installed on the remote system.") + } + } + + return err + } + + log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) + return nil +} + // checkSCPStatus checks that a prior command sent to SCP completed // successfully. If it did not complete successfully, an error will // be returned.