Merge pull request #195 from markpeek/markpeek-file-error

communicator/ssh: check scp codes and report errors
This commit is contained in:
Mitchell Hashimoto 2013-07-19 11:07:03 -07:00
commit 5e9c51ff6d
2 changed files with 40 additions and 8 deletions

View File

@ -1,6 +1,7 @@
package ssh
import (
"bufio"
"bytes"
"code.google.com/p/go.crypto/ssh"
"errors"
@ -106,12 +107,6 @@ func (c *comm) Upload(path string, input io.Reader) error {
return err
}
// Set stderr/stdout to a bytes buffer
stderr := new(bytes.Buffer)
stdout := new(bytes.Buffer)
session.Stderr = stderr
session.Stdout = stdout
// We only want to close once, so we nil w after we close it,
// and only close in the defer if it hasn't been closed already.
defer func() {
@ -120,6 +115,17 @@ func (c *comm) Upload(path string, input io.Reader) error {
}
}()
// Get a pipe to stdout so that we can get responses back
scp_reader, err := session.StdoutPipe()
if err != nil {
return err
}
r := bufio.NewReader(scp_reader)
// Set stderr to a bytes buffer
stderr := new(bytes.Buffer)
session.Stderr = stderr
// The target directory and file for talking the SCP protocol
target_dir := filepath.Dir(path)
target_file := filepath.Base(path)
@ -143,8 +149,17 @@ func (c *comm) Upload(path string, input io.Reader) error {
// Start the protocol
log.Println("Beginning file upload...")
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
err = check_response(r)
if err != nil {
return err
}
io.Copy(w, input_memory)
fmt.Fprint(w, "\x00")
err = check_response(r)
if err != nil {
return err
}
// TODO(mitchellh): Each step above results in a 0/1/2 being sent by
// the remote side to confirm. We should check for those confirmations.
@ -178,7 +193,6 @@ func (c *comm) Upload(path string, input io.Reader) error {
return err
}
log.Printf("scp stdout (length %d): %#v", stdout.Len(), stdout.Bytes())
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
return nil
@ -223,3 +237,17 @@ func (c *comm) reconnect() (err error) {
return
}
func check_response(r *bufio.Reader) (err error) {
scp_status_code, err := r.ReadByte()
if err != nil {
return err
}
if scp_status_code != 0 {
// Treat any non-zero (really 1 and 2) as fatal errors
error_message, _, err := r.ReadLine()
err = fmt.Errorf(string(error_message[:]))
return err
}
return nil
}

View File

@ -79,5 +79,9 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
}
defer f.Close()
return comm.Upload(p.config.Destination, f)
err = comm.Upload(p.config.Destination, f)
if err != nil {
ui.Error(fmt.Sprintf("Upload failed: %s", err))
}
return err
}