From 1948350d201695911dd368222464779f5a3e8757 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sat, 24 Aug 2013 17:14:15 -0700 Subject: [PATCH] communicator/ssh: refactor to upload directories properly --- communicator/ssh/communicator.go | 151 +++++++++++++++++++++++++------ 1 file changed, 122 insertions(+), 29 deletions(-) diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 825dc3ab2..5336556b4 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -10,6 +10,7 @@ import ( "io" "log" "net" + "os" "path/filepath" ) @@ -105,41 +106,29 @@ func (c *comm) Upload(path string, input io.Reader) error { target_dir = filepath.ToSlash(target_dir) scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { - // Determine the length of the upload content by copying it - // into an in-memory buffer. Note that this means what we upload - // must fit into memory. - log.Println("Copying input data into in-memory buffer so we can get the length") - input_memory := new(bytes.Buffer) - if _, err := io.Copy(input_memory, input); err != nil { - return err - } - - // Start the protocol - log.Println("Beginning file upload...") - fmt.Fprintln(w, "C0644", input_memory.Len(), target_file) - err := checkSCPStatus(stdoutR) - if err != nil { - return err - } - - if _, err := io.Copy(w, input_memory); err != nil { - return err - } - - fmt.Fprint(w, "\x00") - err = checkSCPStatus(stdoutR) - if err != nil { - return err - } - - return nil + return scpUploadFile(target_file, input, w, stdoutR) } return c.scpSession("scp -vt "+target_dir, scpFunc) } func (c *comm) UploadDir(dst string, src string, excl []string) error { - return nil + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + + entries, err := f.Readdir(-1) + if err != nil { + return err + } + + scpFunc := func(w io.Writer, r *bufio.Reader) error { + return scpUploadDir(src, entries, w, r) + } + + return c.scpSession("scp -rvt "+dst, scpFunc) } func (c *comm) Download(string, io.Writer) error { @@ -290,3 +279,107 @@ func checkSCPStatus(r *bufio.Reader) error { return nil } + +func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { + // Determine the length of the upload content by copying it + // into an in-memory buffer. Note that this means what we upload + // must fit into memory. + log.Println("Copying input data into in-memory buffer so we can get the length") + inputBuf := new(bytes.Buffer) + if _, err := io.Copy(inputBuf, src); err != nil { + return err + } + + // Start the protocol + log.Println("Beginning file upload...") + fmt.Fprintln(w, "C0644", inputBuf.Len(), dst) + err := checkSCPStatus(r) + if err != nil { + return err + } + + if _, err := io.Copy(w, inputBuf); err != nil { + return err + } + + fmt.Fprint(w, "\x00") + err = checkSCPStatus(r) + if err != nil { + return err + } + + return nil +} + +func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { + for _, fi := range fs { + if !fi.IsDir() { + // It is a regular file, just upload it + f, err := os.Open(filepath.Join(root, fi.Name())) + if err != nil { + return err + } + + err = func() error { + defer f.Close() + return scpUploadFile(fi.Name(), f, w, r) + }() + + if err != nil { + return err + } + } + } + + return nil +} + +func scpWalkFn(cur string, dst string, src string, w io.Writer, r *bufio.Reader) filepath.WalkFunc { + return func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if path == cur { + // Don't upload ourselves + return nil + } + + // Get the relative path so that we can check excludes and also + // so that we can build the full destination path + relPath, err := filepath.Rel(src, path) + if err != nil { + return err + } + + // TODO(mitchellh): Check excludes + targetPath := filepath.Base(relPath) + if info.IsDir() { + log.Printf("SCP: starting directory upload: %s", targetPath) + fmt.Fprintln(w, "D0755 0", targetPath) + err := checkSCPStatus(r) + if err != nil { + return err + } + + err = filepath.Walk(path, scpWalkFn(path, dst, src, w, r)) + if err != nil { + return err + } + + fmt.Fprintln(w, "E") + return checkSCPStatus(r) + } + + // Open the file for uploading + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + // Upload the file like any normal SCP operation + targetPath = filepath.Base(relPath) + return scpUploadFile(targetPath, f, w, r) + } +}