diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index dacd46807..c1fbc5e52 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -1,6 +1,7 @@ package ssh import ( + "bufio" "bytes" "code.google.com/p/go.crypto/ssh" "errors" @@ -106,12 +107,6 @@ func (c *comm) Upload(path string, input io.Reader) error { return err } - // Set stderr/stdout to a bytes buffer - stderr := new(bytes.Buffer) - stdout := new(bytes.Buffer) - session.Stderr = stderr - session.Stdout = stdout - // We only want to close once, so we nil w after we close it, // and only close in the defer if it hasn't been closed already. defer func() { @@ -120,6 +115,17 @@ func (c *comm) Upload(path string, input io.Reader) error { } }() + // Get a pipe to stdout so that we can get responses back + scp_reader, err := session.StdoutPipe() + if err != nil { + return err + } + r := bufio.NewReader(scp_reader) + + // Set stderr to a bytes buffer + stderr := new(bytes.Buffer) + session.Stderr = stderr + // The target directory and file for talking the SCP protocol target_dir := filepath.Dir(path) target_file := filepath.Base(path) @@ -143,8 +149,17 @@ func (c *comm) Upload(path string, input io.Reader) error { // Start the protocol log.Println("Beginning file upload...") fmt.Fprintln(w, "C0644", input_memory.Len(), target_file) + err = check_response(r) + if err != nil { + return err + } + io.Copy(w, input_memory) fmt.Fprint(w, "\x00") + err = check_response(r) + if err != nil { + return err + } // TODO(mitchellh): Each step above results in a 0/1/2 being sent by // the remote side to confirm. We should check for those confirmations. @@ -178,7 +193,6 @@ func (c *comm) Upload(path string, input io.Reader) error { return err } - log.Printf("scp stdout (length %d): %#v", stdout.Len(), stdout.Bytes()) log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) return nil @@ -223,3 +237,17 @@ func (c *comm) reconnect() (err error) { return } + +func check_response(r *bufio.Reader) (err error) { + scp_status_code, err := r.ReadByte() + if err != nil { + return err + } + if scp_status_code != 0 { + // Treat any non-zero (really 1 and 2) as fatal errors + error_message, _, err := r.ReadLine() + err = fmt.Errorf(string(error_message[:])) + return err + } + return nil +}