communicator/ssh: refactor to upload directories properly

This commit is contained in:
Mitchell Hashimoto 2013-08-24 17:14:15 -07:00
parent 05ab50949f
commit a050d344eb
1 changed files with 122 additions and 29 deletions

View File

@ -10,6 +10,7 @@ import (
"io"
"log"
"net"
"os"
"path/filepath"
)
@ -105,41 +106,29 @@ func (c *comm) Upload(path string, input io.Reader) error {
target_dir = filepath.ToSlash(target_dir)
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
// Determine the length of the upload content by copying it
// into an in-memory buffer. Note that this means what we upload
// must fit into memory.
log.Println("Copying input data into in-memory buffer so we can get the length")
input_memory := new(bytes.Buffer)
if _, err := io.Copy(input_memory, input); err != nil {
return err
}
// Start the protocol
log.Println("Beginning file upload...")
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
err := checkSCPStatus(stdoutR)
if err != nil {
return err
}
if _, err := io.Copy(w, input_memory); err != nil {
return err
}
fmt.Fprint(w, "\x00")
err = checkSCPStatus(stdoutR)
if err != nil {
return err
}
return nil
return scpUploadFile(target_file, input, w, stdoutR)
}
return c.scpSession("scp -vt "+target_dir, scpFunc)
}
func (c *comm) UploadDir(dst string, src string, excl []string) error {
return nil
f, err := os.Open(src)
if err != nil {
return err
}
defer f.Close()
entries, err := f.Readdir(-1)
if err != nil {
return err
}
scpFunc := func(w io.Writer, r *bufio.Reader) error {
return scpUploadDir(src, entries, w, r)
}
return c.scpSession("scp -rvt "+dst, scpFunc)
}
func (c *comm) Download(string, io.Writer) error {
@ -290,3 +279,107 @@ func checkSCPStatus(r *bufio.Reader) error {
return nil
}
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
// Determine the length of the upload content by copying it
// into an in-memory buffer. Note that this means what we upload
// must fit into memory.
log.Println("Copying input data into in-memory buffer so we can get the length")
inputBuf := new(bytes.Buffer)
if _, err := io.Copy(inputBuf, src); err != nil {
return err
}
// Start the protocol
log.Println("Beginning file upload...")
fmt.Fprintln(w, "C0644", inputBuf.Len(), dst)
err := checkSCPStatus(r)
if err != nil {
return err
}
if _, err := io.Copy(w, inputBuf); err != nil {
return err
}
fmt.Fprint(w, "\x00")
err = checkSCPStatus(r)
if err != nil {
return err
}
return nil
}
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
for _, fi := range fs {
if !fi.IsDir() {
// It is a regular file, just upload it
f, err := os.Open(filepath.Join(root, fi.Name()))
if err != nil {
return err
}
err = func() error {
defer f.Close()
return scpUploadFile(fi.Name(), f, w, r)
}()
if err != nil {
return err
}
}
}
return nil
}
func scpWalkFn(cur string, dst string, src string, w io.Writer, r *bufio.Reader) filepath.WalkFunc {
return func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if path == cur {
// Don't upload ourselves
return nil
}
// Get the relative path so that we can check excludes and also
// so that we can build the full destination path
relPath, err := filepath.Rel(src, path)
if err != nil {
return err
}
// TODO(mitchellh): Check excludes
targetPath := filepath.Base(relPath)
if info.IsDir() {
log.Printf("SCP: starting directory upload: %s", targetPath)
fmt.Fprintln(w, "D0755 0", targetPath)
err := checkSCPStatus(r)
if err != nil {
return err
}
err = filepath.Walk(path, scpWalkFn(path, dst, src, w, r))
if err != nil {
return err
}
fmt.Fprintln(w, "E")
return checkSCPStatus(r)
}
// Open the file for uploading
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
// Upload the file like any normal SCP operation
targetPath = filepath.Base(relPath)
return scpUploadFile(targetPath, f, w, r)
}
}