communicator/ssh: Trailing slash won't create destination dir

This commit is contained in:
Mitchell Hashimoto 2013-08-25 20:47:10 -07:00
parent 095631107a
commit 86abf14b28
2 changed files with 54 additions and 33 deletions

View File

@ -113,19 +113,30 @@ func (c *comm) Upload(path string, input io.Reader) error {
} }
func (c *comm) UploadDir(dst string, src string, excl []string) error { func (c *comm) UploadDir(dst string, src string, excl []string) error {
f, err := os.Open(src) log.Printf("Upload dir '%s' to '%s'", src, dst)
if err != nil {
return err
}
defer f.Close()
entries, err := f.Readdir(-1)
if err != nil {
return err
}
scpFunc := func(w io.Writer, r *bufio.Reader) error { scpFunc := func(w io.Writer, r *bufio.Reader) error {
return scpUploadDir(src, entries, w, r) uploadEntries := func() error {
f, err := os.Open(src)
if err != nil {
return err
}
defer f.Close()
entries, err := f.Readdir(-1)
if err != nil {
return err
}
return scpUploadDir(src, entries, w, r)
}
if src[len(src)-1] != '/' {
log.Printf("No trailing slash, creating the source directory name")
return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries)
} else {
// Trailing slash, so only upload the contents
return uploadEntries()
}
} }
return c.scpSession("scp -rvt "+dst, scpFunc) return c.scpSession("scp -rvt "+dst, scpFunc)
@ -311,6 +322,26 @@ func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) erro
return nil return nil
} }
func scpUploadDirProtocol(name string, w io.Writer, r *bufio.Reader, f func() error) error {
log.Printf("SCP: starting directory upload: %s", name)
fmt.Fprintln(w, "D0755 0", name)
err := checkSCPStatus(r)
if err != nil {
return err
}
if err := f(); err != nil {
return err
}
fmt.Fprintln(w, "E")
if err != nil {
return err
}
return nil
}
func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error { func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) error {
for _, fi := range fs { for _, fi := range fs {
realPath := filepath.Join(root, fi.Name()) realPath := filepath.Join(root, fi.Name())
@ -335,21 +366,11 @@ func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) e
} }
// It is a directory, recursively upload // It is a directory, recursively upload
log.Printf("SCP: starting directory upload: %s", fi.Name()) err := scpUploadDirProtocol(fi.Name(), w, r, func() error {
fmt.Fprintln(w, "D0755 0", fi.Name()) f, err := os.Open(realPath)
err := checkSCPStatus(r) if err != nil {
if err != nil { return err
return err }
}
f, err := os.Open(realPath)
if err != nil {
return err
}
// Execute this in a function just so that we have easy "defer"
// available because laziness.
err = func() error {
defer f.Close() defer f.Close()
entries, err := f.Readdir(-1) entries, err := f.Readdir(-1)
@ -358,12 +379,7 @@ func scpUploadDir(root string, fs []os.FileInfo, w io.Writer, r *bufio.Reader) e
} }
return scpUploadDir(realPath, entries, w, r) return scpUploadDir(realPath, entries, w, r)
}() })
if err != nil {
return err
}
fmt.Fprintln(w, "E")
if err != nil { if err != nil {
return err return err
} }

View File

@ -62,6 +62,11 @@ type Communicator interface {
// UploadDir uploads the contents of a directory recursively to // UploadDir uploads the contents of a directory recursively to
// the remote path. It also takes an optional slice of paths to // the remote path. It also takes an optional slice of paths to
// ignore when uploading. // ignore when uploading.
//
// The folder name of the source folder should be created unless there
// is a trailing slash on the source "/". For example: "/tmp/src" as
// the source will create a "src" directory in the destination unless
// a trailing slash is added. This is identical behavior to rsync(1).
UploadDir(string, string, []string) error UploadDir(string, string, []string) error
// Download downloads a file from the machine from the given remote path // Download downloads a file from the machine from the given remote path