communicator/ssh: refactor scpSession that we'll reuse for UploadDir

This commit is contained in:
Mitchell Hashimoto 2013-08-23 19:52:02 -07:00
parent d857c9ccbb
commit 05ab50949f
1 changed files with 106 additions and 94 deletions

View File

@ -95,38 +95,6 @@ func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
} }
func (c *comm) Upload(path string, input io.Reader) error { func (c *comm) Upload(path string, input io.Reader) error {
session, err := c.newSession()
if err != nil {
return err
}
defer session.Close()
// Get a pipe to stdin so that we can send data down
w, err := session.StdinPipe()
if err != nil {
return err
}
// We only want to close once, so we nil w after we close it,
// and only close in the defer if it hasn't been closed already.
defer func() {
if w != nil {
w.Close()
}
}()
// Get a pipe to stdout so that we can get responses back
stdoutPipe, err := session.StdoutPipe()
if err != nil {
return err
}
stdoutR := bufio.NewReader(stdoutPipe)
// Set stderr to a bytes buffer
stderr := new(bytes.Buffer)
session.Stderr = stderr
// The target directory and file for talking the SCP protocol // The target directory and file for talking the SCP protocol
target_dir := filepath.Dir(path) target_dir := filepath.Dir(path)
target_file := filepath.Base(path) target_file := filepath.Base(path)
@ -136,72 +104,38 @@ func (c *comm) Upload(path string, input io.Reader) error {
// which works for unix and windows // which works for unix and windows
target_dir = filepath.ToSlash(target_dir) target_dir = filepath.ToSlash(target_dir)
// Start the sink mode on the other side scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
// TODO(mitchellh): There are probably issues with shell escaping the path // Determine the length of the upload content by copying it
log.Println("Starting remote scp process in sink mode") // into an in-memory buffer. Note that this means what we upload
if err = session.Start("scp -vt " + target_dir); err != nil { // must fit into memory.
return err log.Println("Copying input data into in-memory buffer so we can get the length")
} input_memory := new(bytes.Buffer)
if _, err := io.Copy(input_memory, input); err != nil {
// Determine the length of the upload content by copying it return err
// into an in-memory buffer. Note that this means what we upload
// must fit into memory.
log.Println("Copying input data into in-memory buffer so we can get the length")
input_memory := new(bytes.Buffer)
if _, err = io.Copy(input_memory, input); err != nil {
return err
}
// Start the protocol
log.Println("Beginning file upload...")
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
err = checkSCPStatus(stdoutR)
if err != nil {
return err
}
if _, err := io.Copy(w, input_memory); err != nil {
return err
}
fmt.Fprint(w, "\x00")
err = checkSCPStatus(stdoutR)
if err != nil {
return err
}
// Close the stdin, which sends an EOF, and then set w to nil so that
// our defer func doesn't close it again since that is unsafe with
// the Go SSH package.
log.Println("Upload complete, closing stdin pipe")
w.Close()
w = nil
// Wait for the SCP connection to close, meaning it has consumed all
// our data and has completed. Or has errored.
log.Println("Waiting for SSH session to complete")
err = session.Wait()
if err != nil {
if exitErr, ok := err.(*ssh.ExitError); ok {
// Otherwise, we have an ExitErorr, meaning we can just read
// the exit status
log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
// If we exited with status 127, it means SCP isn't available.
// Return a more descriptive error for that.
if exitErr.ExitStatus() == 127 {
return errors.New(
"SCP failed to start. This usually means that SCP is not\n" +
"properly installed on the remote system.")
}
} }
return err // Start the protocol
log.Println("Beginning file upload...")
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
err := checkSCPStatus(stdoutR)
if err != nil {
return err
}
if _, err := io.Copy(w, input_memory); err != nil {
return err
}
fmt.Fprint(w, "\x00")
err = checkSCPStatus(stdoutR)
if err != nil {
return err
}
return nil
} }
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String()) return c.scpSession("scp -vt "+target_dir, scpFunc)
return nil
} }
func (c *comm) UploadDir(dst string, src string, excl []string) error { func (c *comm) UploadDir(dst string, src string, excl []string) error {
@ -257,6 +191,84 @@ func (c *comm) reconnect() (err error) {
return return
} }
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
session, err := c.newSession()
if err != nil {
return err
}
defer session.Close()
// Get a pipe to stdin so that we can send data down
stdinW, err := session.StdinPipe()
if err != nil {
return err
}
// We only want to close once, so we nil w after we close it,
// and only close in the defer if it hasn't been closed already.
defer func() {
if stdinW != nil {
stdinW.Close()
}
}()
// Get a pipe to stdout so that we can get responses back
stdoutPipe, err := session.StdoutPipe()
if err != nil {
return err
}
stdoutR := bufio.NewReader(stdoutPipe)
// Set stderr to a bytes buffer
stderr := new(bytes.Buffer)
session.Stderr = stderr
// Start the sink mode on the other side
// TODO(mitchellh): There are probably issues with shell escaping the path
log.Println("Starting remote scp process: %s", scpCommand)
if err := session.Start(scpCommand); err != nil {
return err
}
// Call our callback that executes in the context of SCP
log.Println("Started SCP session, beginning transfers...")
if err := f(stdinW, stdoutR); err != nil {
return err
}
// Close the stdin, which sends an EOF, and then set w to nil so that
// our defer func doesn't close it again since that is unsafe with
// the Go SSH package.
log.Println("SCP session complete, closing stdin pipe.")
stdinW.Close()
stdinW = nil
// Wait for the SCP connection to close, meaning it has consumed all
// our data and has completed. Or has errored.
log.Println("Waiting for SSH session to complete.")
err = session.Wait()
if err != nil {
if exitErr, ok := err.(*ssh.ExitError); ok {
// Otherwise, we have an ExitErorr, meaning we can just read
// the exit status
log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
// If we exited with status 127, it means SCP isn't available.
// Return a more descriptive error for that.
if exitErr.ExitStatus() == 127 {
return errors.New(
"SCP failed to start. This usually means that SCP is not\n" +
"properly installed on the remote system.")
}
}
return err
}
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
return nil
}
// checkSCPStatus checks that a prior command sent to SCP completed // checkSCPStatus checks that a prior command sent to SCP completed
// successfully. If it did not complete successfully, an error will // successfully. If it did not complete successfully, an error will
// be returned. // be returned.