communicator/ssh: refactor scpSession that we'll reuse for UploadDir
This commit is contained in:
parent
7895df8c8f
commit
e75d3c1fbb
|
@ -95,38 +95,6 @@ func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
|
|||
}
|
||||
|
||||
func (c *comm) Upload(path string, input io.Reader) error {
|
||||
session, err := c.newSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer session.Close()
|
||||
|
||||
// Get a pipe to stdin so that we can send data down
|
||||
w, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We only want to close once, so we nil w after we close it,
|
||||
// and only close in the defer if it hasn't been closed already.
|
||||
defer func() {
|
||||
if w != nil {
|
||||
w.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Get a pipe to stdout so that we can get responses back
|
||||
stdoutPipe, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stdoutR := bufio.NewReader(stdoutPipe)
|
||||
|
||||
// Set stderr to a bytes buffer
|
||||
stderr := new(bytes.Buffer)
|
||||
session.Stderr = stderr
|
||||
|
||||
// The target directory and file for talking the SCP protocol
|
||||
target_dir := filepath.Dir(path)
|
||||
target_file := filepath.Base(path)
|
||||
|
@ -136,72 +104,38 @@ func (c *comm) Upload(path string, input io.Reader) error {
|
|||
// which works for unix and windows
|
||||
target_dir = filepath.ToSlash(target_dir)
|
||||
|
||||
// Start the sink mode on the other side
|
||||
// TODO(mitchellh): There are probably issues with shell escaping the path
|
||||
log.Println("Starting remote scp process in sink mode")
|
||||
if err = session.Start("scp -vt " + target_dir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Determine the length of the upload content by copying it
|
||||
// into an in-memory buffer. Note that this means what we upload
|
||||
// must fit into memory.
|
||||
log.Println("Copying input data into in-memory buffer so we can get the length")
|
||||
input_memory := new(bytes.Buffer)
|
||||
if _, err = io.Copy(input_memory, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the protocol
|
||||
log.Println("Beginning file upload...")
|
||||
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
|
||||
err = checkSCPStatus(stdoutR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(w, input_memory); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprint(w, "\x00")
|
||||
err = checkSCPStatus(stdoutR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the stdin, which sends an EOF, and then set w to nil so that
|
||||
// our defer func doesn't close it again since that is unsafe with
|
||||
// the Go SSH package.
|
||||
log.Println("Upload complete, closing stdin pipe")
|
||||
w.Close()
|
||||
w = nil
|
||||
|
||||
// Wait for the SCP connection to close, meaning it has consumed all
|
||||
// our data and has completed. Or has errored.
|
||||
log.Println("Waiting for SSH session to complete")
|
||||
err = session.Wait()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*ssh.ExitError); ok {
|
||||
// Otherwise, we have an ExitErorr, meaning we can just read
|
||||
// the exit status
|
||||
log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
|
||||
|
||||
// If we exited with status 127, it means SCP isn't available.
|
||||
// Return a more descriptive error for that.
|
||||
if exitErr.ExitStatus() == 127 {
|
||||
return errors.New(
|
||||
"SCP failed to start. This usually means that SCP is not\n" +
|
||||
"properly installed on the remote system.")
|
||||
}
|
||||
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
|
||||
// Determine the length of the upload content by copying it
|
||||
// into an in-memory buffer. Note that this means what we upload
|
||||
// must fit into memory.
|
||||
log.Println("Copying input data into in-memory buffer so we can get the length")
|
||||
input_memory := new(bytes.Buffer)
|
||||
if _, err := io.Copy(input_memory, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
// Start the protocol
|
||||
log.Println("Beginning file upload...")
|
||||
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
|
||||
err := checkSCPStatus(stdoutR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(w, input_memory); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprint(w, "\x00")
|
||||
err = checkSCPStatus(stdoutR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
|
||||
|
||||
return nil
|
||||
return c.scpSession("scp -vt "+target_dir, scpFunc)
|
||||
}
|
||||
|
||||
func (c *comm) UploadDir(dst string, src string, excl []string) error {
|
||||
|
@ -257,6 +191,84 @@ func (c *comm) reconnect() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
|
||||
session, err := c.newSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Get a pipe to stdin so that we can send data down
|
||||
stdinW, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We only want to close once, so we nil w after we close it,
|
||||
// and only close in the defer if it hasn't been closed already.
|
||||
defer func() {
|
||||
if stdinW != nil {
|
||||
stdinW.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Get a pipe to stdout so that we can get responses back
|
||||
stdoutPipe, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stdoutR := bufio.NewReader(stdoutPipe)
|
||||
|
||||
// Set stderr to a bytes buffer
|
||||
stderr := new(bytes.Buffer)
|
||||
session.Stderr = stderr
|
||||
|
||||
// Start the sink mode on the other side
|
||||
// TODO(mitchellh): There are probably issues with shell escaping the path
|
||||
log.Println("Starting remote scp process: %s", scpCommand)
|
||||
if err := session.Start(scpCommand); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Call our callback that executes in the context of SCP
|
||||
log.Println("Started SCP session, beginning transfers...")
|
||||
if err := f(stdinW, stdoutR); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the stdin, which sends an EOF, and then set w to nil so that
|
||||
// our defer func doesn't close it again since that is unsafe with
|
||||
// the Go SSH package.
|
||||
log.Println("SCP session complete, closing stdin pipe.")
|
||||
stdinW.Close()
|
||||
stdinW = nil
|
||||
|
||||
// Wait for the SCP connection to close, meaning it has consumed all
|
||||
// our data and has completed. Or has errored.
|
||||
log.Println("Waiting for SSH session to complete.")
|
||||
err = session.Wait()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*ssh.ExitError); ok {
|
||||
// Otherwise, we have an ExitErorr, meaning we can just read
|
||||
// the exit status
|
||||
log.Printf("non-zero exit status: %d", exitErr.ExitStatus())
|
||||
|
||||
// If we exited with status 127, it means SCP isn't available.
|
||||
// Return a more descriptive error for that.
|
||||
if exitErr.ExitStatus() == 127 {
|
||||
return errors.New(
|
||||
"SCP failed to start. This usually means that SCP is not\n" +
|
||||
"properly installed on the remote system.")
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("scp stderr (length %d): %s", stderr.Len(), stderr.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSCPStatus checks that a prior command sent to SCP completed
|
||||
// successfully. If it did not complete successfully, an error will
|
||||
// be returned.
|
||||
|
|
Loading…
Reference in New Issue