From a59c82d7a6f4f82e362203f3086c51d58b4ae78a Mon Sep 17 00:00:00 2001 From: Alfonso Acosta Date: Sun, 26 Jul 2015 23:39:56 +0000 Subject: [PATCH] Add sftp file transfer support Adds a new config option: "ssh_file_transfer_method", which can be set to "scp" or "sftp" (defaults to "scp") --- communicator/ssh/communicator.go | 365 +++++++++++++----- helper/communicator/config.go | 39 +- helper/communicator/step_connect_ssh.go | 1 + .../dir_no_trailing_sftp.json | 23 ++ .../dir_with_trailing_sftp.json | 23 ++ test/fixtures/provisioner-file/file_sftp.json | 23 ++ test/provisioner_file.bats | 18 + .../docs/templates/communicator.html.md | 3 + 8 files changed, 389 insertions(+), 106 deletions(-) create mode 100644 test/fixtures/provisioner-file/dir_no_trailing_sftp.json create mode 100644 test/fixtures/provisioner-file/dir_with_trailing_sftp.json create mode 100644 test/fixtures/provisioner-file/file_sftp.json diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 0b2acd8f8..19302e89a 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -16,6 +16,7 @@ import ( "time" "github.com/mitchellh/packer/packer" + "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) @@ -51,6 +52,9 @@ type Config struct { // HandshakeTimeout limits the amount of time we'll wait to handshake before // saying the connection failed. HandshakeTimeout time.Duration + + // UseSftp, if true, sftp will be used instead of scp for file transfers + UseSftp bool } // Creates a new packer.Communicator implementation over SSH. This takes @@ -136,107 +140,28 @@ func (c *comm) Start(cmd *packer.RemoteCmd) (err error) { } func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error { - // The target directory and file for talking the SCP protocol - target_dir := filepath.Dir(path) - target_file := filepath.Base(path) - - // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). - // This does not work when the target host is unix. Switch to forward slash - // which works for unix and windows - target_dir = filepath.ToSlash(target_dir) - - scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { - return scpUploadFile(target_file, input, w, stdoutR, fi) + if c.config.UseSftp { + return c.sftpUploadSession(path, input, fi) + } else { + return c.scpUploadSession(path, input, fi) } - - return c.scpSession("scp -vt "+target_dir, scpFunc) } func (c *comm) UploadDir(dst string, src string, excl []string) error { log.Printf("Upload dir '%s' to '%s'", src, dst) - scpFunc := func(w io.Writer, r *bufio.Reader) error { - uploadEntries := func() error { - f, err := os.Open(src) - if err != nil { - return err - } - defer f.Close() - - entries, err := f.Readdir(-1) - if err != nil { - return err - } - - return scpUploadDir(src, entries, w, r) - } - - if src[len(src)-1] != '/' { - log.Printf("No trailing slash, creating the source directory name") - fi, err := os.Stat(src) - if err != nil { - return err - } - return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi) - } else { - // Trailing slash, so only upload the contents - return uploadEntries() - } + if c.config.UseSftp { + return c.sftpUploadDirSession(dst, src, excl) + } else { + return c.scpUploadDirSession(dst, src, excl) } - - return c.scpSession("scp -rvt "+dst, scpFunc) } func (c *comm) Download(path string, output io.Writer) error { - scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { - fmt.Fprint(w, "\x00") - - // read file info - fi, err := stdoutR.ReadString('\n') - if err != nil { - return err - } - - if len(fi) < 0 { - return fmt.Errorf("empty response from server") - } - - switch fi[0] { - case '\x01', '\x02': - return fmt.Errorf("%s", fi[1:len(fi)]) - case 'C': - case 'D': - return fmt.Errorf("remote file is directory") - default: - return fmt.Errorf("unexpected server response (%x)", fi[0]) - } - - var mode string - var size int64 - - n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size) - if err != nil || n != 2 { - return fmt.Errorf("can't parse server response (%s)", fi) - } - if size < 0 { - return fmt.Errorf("negative file size") - } - - fmt.Fprint(w, "\x00") - - if _, err := io.CopyN(output, stdoutR, size); err != nil { - return err - } - - fmt.Fprint(w, "\x00") - - if err := checkSCPStatus(stdoutR); err != nil { - return err - } - - return nil + if c.config.UseSftp { + return c.sftpDownloadSession(path, output) + } else { + return c.scpDownloadSession(path, output) } - - return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc) } func (c *comm) newSession() (session *ssh.Session, err error) { @@ -385,6 +310,262 @@ func (c *comm) connectToAgent() { return } +func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error { + sftpFunc := func(client *sftp.Client) error { + return sftpUploadFile(path, input, client, fi) + } + + return c.sftpSession(sftpFunc) +} + +func sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error { + log.Printf("[DEBUG] sftp: uploading %s", path) + + f, err := client.Create(path) + if err != nil { + return err + } + defer f.Close() + + if _, err = io.Copy(f, input); err != nil { + return err + } + + if fi != nil && (*fi).Mode().IsRegular() { + mode := (*fi).Mode().Perm() + err = client.Chmod(path, mode) + if err != nil { + return err + } + } + + return nil +} + +func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error { + sftpFunc := func(client *sftp.Client) error { + rootDst := dst + if src[len(src)-1] != '/' { + log.Printf("No trailing slash, creating the source directory name") + rootDst = filepath.Join(dst, filepath.Base(src)) + } + walkFunc := func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + // Calculate the final destination using the + // base source and root destination + relSrc, err := filepath.Rel(src, path) + if err != nil { + return err + } + finalDst := filepath.Join(rootDst, relSrc) + + // In Windows, Join uses backslashes which we don't want to get + // to the sftp server + finalDst = filepath.ToSlash(finalDst) + + // Skip the creation of the target destination directory since + // it should exist and we might not even own it + if finalDst == dst { + return nil + } + + return sftpVisitFile(finalDst, path, info, client) + } + + return filepath.Walk(src, walkFunc) + } + + return c.sftpSession(sftpFunc) +} + +func sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error { + log.Printf("[DEBUG] sftp: creating dir %s", path) + + if err := client.Mkdir(path); err != nil { + // Do not consider it an error if the directory existed + remoteFi, fiErr := client.Lstat(path) + if fiErr != nil || !remoteFi.IsDir() { + return err + } + } + + mode := fi.Mode().Perm() + if err := client.Chmod(path, mode); err != nil { + return err + } + return nil +} + +func sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error { + if !fi.IsDir() { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + return sftpUploadFile(dst, f, client, &fi) + } else { + err := sftpMkdir(dst, client, fi) + return err + } +} + +func (c *comm) sftpDownloadSession(path string, output io.Writer) error { + sftpFunc := func(client *sftp.Client) error { + f, err := client.Open(path) + if err != nil { + return err + } + defer f.Close() + + if _, err = io.Copy(output, f); err != nil { + return err + } + + return nil + } + + return c.sftpSession(sftpFunc) +} + +func (c *comm) sftpSession(f func(*sftp.Client) error) error { + client, err := c.newSftpClient() + if err != nil { + return err + } + defer client.Close() + + return f(client) +} + +func (c *comm) newSftpClient() (*sftp.Client, error) { + session, err := c.newSession() + if err != nil { + return nil, err + } + + if err := session.RequestSubsystem("sftp"); err != nil { + return nil, err + } + + pw, err := session.StdinPipe() + if err != nil { + return nil, err + } + pr, err := session.StdoutPipe() + if err != nil { + return nil, err + } + + return sftp.NewClientPipe(pr, pw) +} + +func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error { + + // The target directory and file for talking the SCP protocol + target_dir := filepath.Dir(path) + target_file := filepath.Base(path) + + // On windows, filepath.Dir uses backslash seperators (ie. "\tmp"). + // This does not work when the target host is unix. Switch to forward slash + // which works for unix and windows + target_dir = filepath.ToSlash(target_dir) + + scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { + return scpUploadFile(target_file, input, w, stdoutR, fi) + } + + return c.scpSession("scp -vt "+target_dir, scpFunc) +} + +func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error { + scpFunc := func(w io.Writer, r *bufio.Reader) error { + uploadEntries := func() error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + + entries, err := f.Readdir(-1) + if err != nil { + return err + } + + return scpUploadDir(src, entries, w, r) + } + + if src[len(src)-1] != '/' { + log.Printf("No trailing slash, creating the source directory name") + fi, err := os.Stat(src) + if err != nil { + return err + } + return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi) + } else { + // Trailing slash, so only upload the contents + return uploadEntries() + } + } + + return c.scpSession("scp -rvt "+dst, scpFunc) +} + +func (c *comm) scpDownloadSession(path string, output io.Writer) error { + scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { + fmt.Fprint(w, "\x00") + + // read file info + fi, err := stdoutR.ReadString('\n') + if err != nil { + return err + } + + if len(fi) < 0 { + return fmt.Errorf("empty response from server") + } + + switch fi[0] { + case '\x01', '\x02': + return fmt.Errorf("%s", fi[1:len(fi)]) + case 'C': + case 'D': + return fmt.Errorf("remote file is directory") + default: + return fmt.Errorf("unexpected server response (%x)", fi[0]) + } + + var mode string + var size int64 + + n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size) + if err != nil || n != 2 { + return fmt.Errorf("can't parse server response (%s)", fi) + } + if size < 0 { + return fmt.Errorf("negative file size") + } + + fmt.Fprint(w, "\x00") + + if _, err := io.CopyN(output, stdoutR, size); err != nil { + return err + } + + fmt.Fprint(w, "\x00") + + if err := checkSCPStatus(stdoutR); err != nil { + return err + } + + return nil + } + + return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc) +} + func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error { session, err := c.newSession() if err != nil { @@ -533,7 +714,7 @@ func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi * // Start the protocol perms := fmt.Sprintf("C%04o", mode) - log.Printf("[DEBUG] Uploading %s: perms=%s size=%d", dst, perms, size) + log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size) fmt.Fprintln(w, perms, size, dst) if err := checkSCPStatus(r); err != nil { diff --git a/helper/communicator/config.go b/helper/communicator/config.go index 0f19c4e68..fc57d4316 100644 --- a/helper/communicator/config.go +++ b/helper/communicator/config.go @@ -15,20 +15,21 @@ type Config struct { Type string `mapstructure:"communicator"` // SSH - SSHHost string `mapstructure:"ssh_host"` - SSHPort int `mapstructure:"ssh_port"` - SSHUsername string `mapstructure:"ssh_username"` - SSHPassword string `mapstructure:"ssh_password"` - SSHPrivateKey string `mapstructure:"ssh_private_key_file"` - SSHPty bool `mapstructure:"ssh_pty"` - SSHTimeout time.Duration `mapstructure:"ssh_timeout"` - SSHDisableAgent bool `mapstructure:"ssh_disable_agent"` - SSHHandshakeAttempts int `mapstructure:"ssh_handshake_attempts"` - SSHBastionHost string `mapstructure:"ssh_bastion_host"` - SSHBastionPort int `mapstructure:"ssh_bastion_port"` - SSHBastionUsername string `mapstructure:"ssh_bastion_username"` - SSHBastionPassword string `mapstructure:"ssh_bastion_password"` - SSHBastionPrivateKey string `mapstructure:"ssh_bastion_private_key_file"` + SSHHost string `mapstructure:"ssh_host"` + SSHPort int `mapstructure:"ssh_port"` + SSHUsername string `mapstructure:"ssh_username"` + SSHPassword string `mapstructure:"ssh_password"` + SSHPrivateKey string `mapstructure:"ssh_private_key_file"` + SSHPty bool `mapstructure:"ssh_pty"` + SSHTimeout time.Duration `mapstructure:"ssh_timeout"` + SSHDisableAgent bool `mapstructure:"ssh_disable_agent"` + SSHHandshakeAttempts int `mapstructure:"ssh_handshake_attempts"` + SSHBastionHost string `mapstructure:"ssh_bastion_host"` + SSHBastionPort int `mapstructure:"ssh_bastion_port"` + SSHBastionUsername string `mapstructure:"ssh_bastion_username"` + SSHBastionPassword string `mapstructure:"ssh_bastion_password"` + SSHBastionPrivateKey string `mapstructure:"ssh_bastion_private_key_file"` + SSHFileTransferMethod string `mapstructure:"ssh_file_transfer_method"` // WinRM WinRMUser string `mapstructure:"winrm_username"` @@ -93,6 +94,10 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { } } + if c.SSHFileTransferMethod == "" { + c.SSHFileTransferMethod = "scp" + } + // Validation var errs []error if c.SSHUsername == "" { @@ -116,6 +121,12 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { } } + if c.SSHFileTransferMethod != "scp" && c.SSHFileTransferMethod != "sftp" { + errs = append(errs, fmt.Errorf( + "ssh_file_transfer_method ('%s') is invalid, valid methods: sftp, scp", + c.SSHFileTransferMethod)) + } + return errs } diff --git a/helper/communicator/step_connect_ssh.go b/helper/communicator/step_connect_ssh.go index 0d302f779..71a6d1a39 100644 --- a/helper/communicator/step_connect_ssh.go +++ b/helper/communicator/step_connect_ssh.go @@ -162,6 +162,7 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru SSHConfig: sshConfig, Pty: s.Config.SSHPty, DisableAgent: s.Config.SSHDisableAgent, + UseSftp: s.Config.SSHFileTransferMethod == "sftp", } log.Println("[INFO] Attempting SSH connection...") diff --git a/test/fixtures/provisioner-file/dir_no_trailing_sftp.json b/test/fixtures/provisioner-file/dir_no_trailing_sftp.json new file mode 100644 index 000000000..283751698 --- /dev/null +++ b/test/fixtures/provisioner-file/dir_no_trailing_sftp.json @@ -0,0 +1,23 @@ +{ + "builders": [{ + "type": "amazon-ebs", + "ami_name": "packer-test {{timestamp}}", + "instance_type": "t2.micro", + "region": "us-east-1", + "ssh_username": "ec2-user", + "ssh_file_transfer_method": "sftp", + "source_ami": "ami-8da458e6", + "tags": { + "packer-test": "true" + } + }], + + "provisioners": [{ + "type": "file", + "source": "dir", + "destination": "/tmp" + }, { + "type": "shell", + "inline": ["cat /tmp/dir/file.txt"] + }] +} diff --git a/test/fixtures/provisioner-file/dir_with_trailing_sftp.json b/test/fixtures/provisioner-file/dir_with_trailing_sftp.json new file mode 100644 index 000000000..f26365434 --- /dev/null +++ b/test/fixtures/provisioner-file/dir_with_trailing_sftp.json @@ -0,0 +1,23 @@ +{ + "builders": [{ + "type": "amazon-ebs", + "ami_name": "packer-test {{timestamp}}", + "instance_type": "t2.micro", + "region": "us-east-1", + "ssh_username": "ec2-user", + "ssh_file_transfer_method": "sftp", + "source_ami": "ami-8da458e6", + "tags": { + "packer-test": "true" + } + }], + + "provisioners": [{ + "type": "file", + "source": "dir/", + "destination": "/tmp" + }, { + "type": "shell", + "inline": ["cat /tmp/file.txt"] + }] +} diff --git a/test/fixtures/provisioner-file/file_sftp.json b/test/fixtures/provisioner-file/file_sftp.json new file mode 100644 index 000000000..19a8a98b8 --- /dev/null +++ b/test/fixtures/provisioner-file/file_sftp.json @@ -0,0 +1,23 @@ +{ + "builders": [{ + "type": "amazon-ebs", + "ami_name": "packer-test {{timestamp}}", + "instance_type": "t2.micro", + "region": "us-east-1", + "ssh_username": "ec2-user", + "ssh_file_transfer_method": "sftp", + "source_ami": "ami-8da458e6", + "tags": { + "packer-test": "true" + } + }], + + "provisioners": [{ + "type": "file", + "source": "file.txt", + "destination": "/tmp/file.txt" + }, { + "type": "shell", + "inline": ["cat /tmp/file.txt"] + }] +} diff --git a/test/provisioner_file.bats b/test/provisioner_file.bats index dafe7771a..5800d2d82 100755 --- a/test/provisioner_file.bats +++ b/test/provisioner_file.bats @@ -32,3 +32,21 @@ teardown() { [ "$status" -eq 0 ] [[ "$output" == *"337 miles"* ]] } + +@test "file provisioner: single file through sftp" { + run packer build $FIXTURE_ROOT/file_sftp.json + [ "$status" -eq 0 ] + [[ "$output" == *"24901 miles"* ]] +} + +@test "file provisioner: directory through sftp (no trailing slash)" { + run packer build $FIXTURE_ROOT/dir_no_trailing_sftp.json + [ "$status" -eq 0 ] + [[ "$output" == *"337 miles"* ]] +} + +@test "file provisioner: directory through sftp (with trailing slash)" { + run packer build $FIXTURE_ROOT/dir_with_trailing_sftp.json + [ "$status" -eq 0 ] + [[ "$output" == *"337 miles"* ]] +} diff --git a/website/source/docs/templates/communicator.html.md b/website/source/docs/templates/communicator.html.md index f38815309..b3e2eab06 100644 --- a/website/source/docs/templates/communicator.html.md +++ b/website/source/docs/templates/communicator.html.md @@ -94,6 +94,9 @@ The SSH communicator has the following options: * `ssh_bastion_private_key_file` (string) - A private key file to use to authenticate with the bastion host. + + * `ssh_file_transfer_method` (`scp` or `sftp`) - How to transfer files, Secure + copy (default) or SSH File Transfer Protocol. ## WinRM Communicator