communicator/ssh: refactor to upload directories properly
This commit is contained in:
parent
e75d3c1fbb
commit
1948350d20
|
@ -10,6 +10,7 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
|
@ -105,41 +106,29 @@ func (c *comm) Upload(path string, input io.Reader) error {
|
|||
target_dir = filepath.ToSlash(target_dir)
|
||||
|
||||
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
|
||||
// Determine the length of the upload content by copying it
|
||||
// into an in-memory buffer. Note that this means what we upload
|
||||
// must fit into memory.
|
||||
log.Println("Copying input data into in-memory buffer so we can get the length")
|
||||
input_memory := new(bytes.Buffer)
|
||||
if _, err := io.Copy(input_memory, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the protocol
|
||||
log.Println("Beginning file upload...")
|
||||
fmt.Fprintln(w, "C0644", input_memory.Len(), target_file)
|
||||
err := checkSCPStatus(stdoutR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(w, input_memory); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprint(w, "\x00")
|
||||
err = checkSCPStatus(stdoutR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return scpUploadFile(target_file, input, w, stdoutR)
|
||||
}
|
||||
|
||||
return c.scpSession("scp -vt "+target_dir, scpFunc)
|
||||
}
|
||||
|
||||
func (c *comm) UploadDir(dst string, src string, excl []string) error {
|
||||
return nil
|
||||
f, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
entries, err := f.Readdir(-1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scpFunc := func(w io.Writer, r *bufio.Reader) error {
|
||||
return scpUploadDir(src, entries, w, r)
|
||||
}
|
||||
|
||||
return c.scpSession("scp -rvt "+dst, scpFunc)
|
||||
}
|
||||
|
||||
func (c *comm) Download(string, io.Writer) error {
|
||||
|
@ -290,3 +279,107 @@ func checkSCPStatus(r *bufio.Reader) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
|
||||
// Determine the length of the upload content by copying it
|
||||
// into an in-memory buffer. Note that this means what we upload
|
||||
// must fit into memory.
|
||||
log.Println("Copying input data into in-memory buffer so we can get the length")
|
||||
inputBuf := new(bytes.Buffer)
|
||||
if _, err := io.Copy(inputBuf, src); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the protocol
|
||||
log.Println("Beginning file upload...")
|
||||
fmt.Fprintln(w, "C0644", inputBuf.Len(), dst)
|
||||
err := checkSCPStatus(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(w, inputBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprint(w, "\x00")
|
||||
err = checkSCPStatus(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
|
||||
for _, fi := range fs {
|
||||
if !fi.IsDir() {
|
||||
// It is a regular file, just upload it
|
||||
f, err := os.Open(filepath.Join(root, fi.Name()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = func() error {
|
||||
defer f.Close()
|
||||
return scpUploadFile(fi.Name(), f, w, r)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scpWalkFn(cur string, dst string, src string, w io.Writer, r *bufio.Reader) filepath.WalkFunc {
|
||||
return func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if path == cur {
|
||||
// Don't upload ourselves
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the relative path so that we can check excludes and also
|
||||
// so that we can build the full destination path
|
||||
relPath, err := filepath.Rel(src, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(mitchellh): Check excludes
|
||||
targetPath := filepath.Base(relPath)
|
||||
if info.IsDir() {
|
||||
log.Printf("SCP: starting directory upload: %s", targetPath)
|
||||
fmt.Fprintln(w, "D0755 0", targetPath)
|
||||
err := checkSCPStatus(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = filepath.Walk(path, scpWalkFn(path, dst, src, w, r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintln(w, "E")
|
||||
return checkSCPStatus(r)
|
||||
}
|
||||
|
||||
// Open the file for uploading
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Upload the file like any normal SCP operation
|
||||
targetPath = filepath.Base(relPath)
|
||||
return scpUploadFile(targetPath, f, w, r)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue