Add sftp file transfer support

Adds a new config option: "ssh_file_transfer_method", which can be set to "scp"
or "sftp" (defaults to "scp")
This commit is contained in:
Alfonso Acosta 2015-07-26 23:39:56 +00:00
parent ef873ba210
commit a59c82d7a6
8 changed files with 389 additions and 106 deletions

View File

@ -16,6 +16,7 @@ import (
"time"
"github.com/mitchellh/packer/packer"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
@ -51,6 +52,9 @@ type Config struct {
// HandshakeTimeout limits the amount of time we'll wait to handshake before
// saying the connection failed.
HandshakeTimeout time.Duration
// UseSftp, if true, sftp will be used instead of scp for file transfers
UseSftp bool
}
// Creates a new packer.Communicator implementation over SSH. This takes
@ -136,107 +140,28 @@ func (c *comm) Start(cmd *packer.RemoteCmd) (err error) {
}
func (c *comm) Upload(path string, input io.Reader, fi *os.FileInfo) error {
// The target directory and file for talking the SCP protocol
target_dir := filepath.Dir(path)
target_file := filepath.Base(path)
// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
// This does not work when the target host is unix. Switch to forward slash
// which works for unix and windows
target_dir = filepath.ToSlash(target_dir)
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
return scpUploadFile(target_file, input, w, stdoutR, fi)
if c.config.UseSftp {
return c.sftpUploadSession(path, input, fi)
} else {
return c.scpUploadSession(path, input, fi)
}
return c.scpSession("scp -vt "+target_dir, scpFunc)
}
func (c *comm) UploadDir(dst string, src string, excl []string) error {
log.Printf("Upload dir '%s' to '%s'", src, dst)
scpFunc := func(w io.Writer, r *bufio.Reader) error {
uploadEntries := func() error {
f, err := os.Open(src)
if err != nil {
return err
}
defer f.Close()
entries, err := f.Readdir(-1)
if err != nil {
return err
}
return scpUploadDir(src, entries, w, r)
}
if src[len(src)-1] != '/' {
log.Printf("No trailing slash, creating the source directory name")
fi, err := os.Stat(src)
if err != nil {
return err
}
return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi)
} else {
// Trailing slash, so only upload the contents
return uploadEntries()
}
if c.config.UseSftp {
return c.sftpUploadDirSession(dst, src, excl)
} else {
return c.scpUploadDirSession(dst, src, excl)
}
return c.scpSession("scp -rvt "+dst, scpFunc)
}
func (c *comm) Download(path string, output io.Writer) error {
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
fmt.Fprint(w, "\x00")
// read file info
fi, err := stdoutR.ReadString('\n')
if err != nil {
return err
}
if len(fi) < 0 {
return fmt.Errorf("empty response from server")
}
switch fi[0] {
case '\x01', '\x02':
return fmt.Errorf("%s", fi[1:len(fi)])
case 'C':
case 'D':
return fmt.Errorf("remote file is directory")
default:
return fmt.Errorf("unexpected server response (%x)", fi[0])
}
var mode string
var size int64
n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size)
if err != nil || n != 2 {
return fmt.Errorf("can't parse server response (%s)", fi)
}
if size < 0 {
return fmt.Errorf("negative file size")
}
fmt.Fprint(w, "\x00")
if _, err := io.CopyN(output, stdoutR, size); err != nil {
return err
}
fmt.Fprint(w, "\x00")
if err := checkSCPStatus(stdoutR); err != nil {
return err
}
return nil
if c.config.UseSftp {
return c.sftpDownloadSession(path, output)
} else {
return c.scpDownloadSession(path, output)
}
return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc)
}
func (c *comm) newSession() (session *ssh.Session, err error) {
@ -385,6 +310,262 @@ func (c *comm) connectToAgent() {
return
}
func (c *comm) sftpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
sftpFunc := func(client *sftp.Client) error {
return sftpUploadFile(path, input, client, fi)
}
return c.sftpSession(sftpFunc)
}
func sftpUploadFile(path string, input io.Reader, client *sftp.Client, fi *os.FileInfo) error {
log.Printf("[DEBUG] sftp: uploading %s", path)
f, err := client.Create(path)
if err != nil {
return err
}
defer f.Close()
if _, err = io.Copy(f, input); err != nil {
return err
}
if fi != nil && (*fi).Mode().IsRegular() {
mode := (*fi).Mode().Perm()
err = client.Chmod(path, mode)
if err != nil {
return err
}
}
return nil
}
func (c *comm) sftpUploadDirSession(dst string, src string, excl []string) error {
sftpFunc := func(client *sftp.Client) error {
rootDst := dst
if src[len(src)-1] != '/' {
log.Printf("No trailing slash, creating the source directory name")
rootDst = filepath.Join(dst, filepath.Base(src))
}
walkFunc := func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Calculate the final destination using the
// base source and root destination
relSrc, err := filepath.Rel(src, path)
if err != nil {
return err
}
finalDst := filepath.Join(rootDst, relSrc)
// In Windows, Join uses backslashes which we don't want to get
// to the sftp server
finalDst = filepath.ToSlash(finalDst)
// Skip the creation of the target destination directory since
// it should exist and we might not even own it
if finalDst == dst {
return nil
}
return sftpVisitFile(finalDst, path, info, client)
}
return filepath.Walk(src, walkFunc)
}
return c.sftpSession(sftpFunc)
}
func sftpMkdir(path string, client *sftp.Client, fi os.FileInfo) error {
log.Printf("[DEBUG] sftp: creating dir %s", path)
if err := client.Mkdir(path); err != nil {
// Do not consider it an error if the directory existed
remoteFi, fiErr := client.Lstat(path)
if fiErr != nil || !remoteFi.IsDir() {
return err
}
}
mode := fi.Mode().Perm()
if err := client.Chmod(path, mode); err != nil {
return err
}
return nil
}
func sftpVisitFile(dst string, src string, fi os.FileInfo, client *sftp.Client) error {
if !fi.IsDir() {
f, err := os.Open(src)
if err != nil {
return err
}
defer f.Close()
return sftpUploadFile(dst, f, client, &fi)
} else {
err := sftpMkdir(dst, client, fi)
return err
}
}
func (c *comm) sftpDownloadSession(path string, output io.Writer) error {
sftpFunc := func(client *sftp.Client) error {
f, err := client.Open(path)
if err != nil {
return err
}
defer f.Close()
if _, err = io.Copy(output, f); err != nil {
return err
}
return nil
}
return c.sftpSession(sftpFunc)
}
func (c *comm) sftpSession(f func(*sftp.Client) error) error {
client, err := c.newSftpClient()
if err != nil {
return err
}
defer client.Close()
return f(client)
}
func (c *comm) newSftpClient() (*sftp.Client, error) {
session, err := c.newSession()
if err != nil {
return nil, err
}
if err := session.RequestSubsystem("sftp"); err != nil {
return nil, err
}
pw, err := session.StdinPipe()
if err != nil {
return nil, err
}
pr, err := session.StdoutPipe()
if err != nil {
return nil, err
}
return sftp.NewClientPipe(pr, pw)
}
func (c *comm) scpUploadSession(path string, input io.Reader, fi *os.FileInfo) error {
// The target directory and file for talking the SCP protocol
target_dir := filepath.Dir(path)
target_file := filepath.Base(path)
// On windows, filepath.Dir uses backslash seperators (ie. "\tmp").
// This does not work when the target host is unix. Switch to forward slash
// which works for unix and windows
target_dir = filepath.ToSlash(target_dir)
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
return scpUploadFile(target_file, input, w, stdoutR, fi)
}
return c.scpSession("scp -vt "+target_dir, scpFunc)
}
func (c *comm) scpUploadDirSession(dst string, src string, excl []string) error {
scpFunc := func(w io.Writer, r *bufio.Reader) error {
uploadEntries := func() error {
f, err := os.Open(src)
if err != nil {
return err
}
defer f.Close()
entries, err := f.Readdir(-1)
if err != nil {
return err
}
return scpUploadDir(src, entries, w, r)
}
if src[len(src)-1] != '/' {
log.Printf("No trailing slash, creating the source directory name")
fi, err := os.Stat(src)
if err != nil {
return err
}
return scpUploadDirProtocol(filepath.Base(src), w, r, uploadEntries, fi)
} else {
// Trailing slash, so only upload the contents
return uploadEntries()
}
}
return c.scpSession("scp -rvt "+dst, scpFunc)
}
func (c *comm) scpDownloadSession(path string, output io.Writer) error {
scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error {
fmt.Fprint(w, "\x00")
// read file info
fi, err := stdoutR.ReadString('\n')
if err != nil {
return err
}
if len(fi) < 0 {
return fmt.Errorf("empty response from server")
}
switch fi[0] {
case '\x01', '\x02':
return fmt.Errorf("%s", fi[1:len(fi)])
case 'C':
case 'D':
return fmt.Errorf("remote file is directory")
default:
return fmt.Errorf("unexpected server response (%x)", fi[0])
}
var mode string
var size int64
n, err := fmt.Sscanf(fi, "%6s %d ", &mode, &size)
if err != nil || n != 2 {
return fmt.Errorf("can't parse server response (%s)", fi)
}
if size < 0 {
return fmt.Errorf("negative file size")
}
fmt.Fprint(w, "\x00")
if _, err := io.CopyN(output, stdoutR, size); err != nil {
return err
}
fmt.Fprint(w, "\x00")
if err := checkSCPStatus(stdoutR); err != nil {
return err
}
return nil
}
return c.scpSession("scp -vf "+strconv.Quote(path), scpFunc)
}
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
session, err := c.newSession()
if err != nil {
@ -533,7 +714,7 @@ func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, fi *
// Start the protocol
perms := fmt.Sprintf("C%04o", mode)
log.Printf("[DEBUG] Uploading %s: perms=%s size=%d", dst, perms, size)
log.Printf("[DEBUG] scp: Uploading %s: perms=%s size=%d", dst, perms, size)
fmt.Fprintln(w, perms, size, dst)
if err := checkSCPStatus(r); err != nil {

View File

@ -15,20 +15,21 @@ type Config struct {
Type string `mapstructure:"communicator"`
// SSH
SSHHost string `mapstructure:"ssh_host"`
SSHPort int `mapstructure:"ssh_port"`
SSHUsername string `mapstructure:"ssh_username"`
SSHPassword string `mapstructure:"ssh_password"`
SSHPrivateKey string `mapstructure:"ssh_private_key_file"`
SSHPty bool `mapstructure:"ssh_pty"`
SSHTimeout time.Duration `mapstructure:"ssh_timeout"`
SSHDisableAgent bool `mapstructure:"ssh_disable_agent"`
SSHHandshakeAttempts int `mapstructure:"ssh_handshake_attempts"`
SSHBastionHost string `mapstructure:"ssh_bastion_host"`
SSHBastionPort int `mapstructure:"ssh_bastion_port"`
SSHBastionUsername string `mapstructure:"ssh_bastion_username"`
SSHBastionPassword string `mapstructure:"ssh_bastion_password"`
SSHBastionPrivateKey string `mapstructure:"ssh_bastion_private_key_file"`
SSHHost string `mapstructure:"ssh_host"`
SSHPort int `mapstructure:"ssh_port"`
SSHUsername string `mapstructure:"ssh_username"`
SSHPassword string `mapstructure:"ssh_password"`
SSHPrivateKey string `mapstructure:"ssh_private_key_file"`
SSHPty bool `mapstructure:"ssh_pty"`
SSHTimeout time.Duration `mapstructure:"ssh_timeout"`
SSHDisableAgent bool `mapstructure:"ssh_disable_agent"`
SSHHandshakeAttempts int `mapstructure:"ssh_handshake_attempts"`
SSHBastionHost string `mapstructure:"ssh_bastion_host"`
SSHBastionPort int `mapstructure:"ssh_bastion_port"`
SSHBastionUsername string `mapstructure:"ssh_bastion_username"`
SSHBastionPassword string `mapstructure:"ssh_bastion_password"`
SSHBastionPrivateKey string `mapstructure:"ssh_bastion_private_key_file"`
SSHFileTransferMethod string `mapstructure:"ssh_file_transfer_method"`
// WinRM
WinRMUser string `mapstructure:"winrm_username"`
@ -93,6 +94,10 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error {
}
}
if c.SSHFileTransferMethod == "" {
c.SSHFileTransferMethod = "scp"
}
// Validation
var errs []error
if c.SSHUsername == "" {
@ -116,6 +121,12 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error {
}
}
if c.SSHFileTransferMethod != "scp" && c.SSHFileTransferMethod != "sftp" {
errs = append(errs, fmt.Errorf(
"ssh_file_transfer_method ('%s') is invalid, valid methods: sftp, scp",
c.SSHFileTransferMethod))
}
return errs
}

View File

@ -162,6 +162,7 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
SSHConfig: sshConfig,
Pty: s.Config.SSHPty,
DisableAgent: s.Config.SSHDisableAgent,
UseSftp: s.Config.SSHFileTransferMethod == "sftp",
}
log.Println("[INFO] Attempting SSH connection...")

View File

@ -0,0 +1,23 @@
{
"builders": [{
"type": "amazon-ebs",
"ami_name": "packer-test {{timestamp}}",
"instance_type": "t2.micro",
"region": "us-east-1",
"ssh_username": "ec2-user",
"ssh_file_transfer_method": "sftp",
"source_ami": "ami-8da458e6",
"tags": {
"packer-test": "true"
}
}],
"provisioners": [{
"type": "file",
"source": "dir",
"destination": "/tmp"
}, {
"type": "shell",
"inline": ["cat /tmp/dir/file.txt"]
}]
}

View File

@ -0,0 +1,23 @@
{
"builders": [{
"type": "amazon-ebs",
"ami_name": "packer-test {{timestamp}}",
"instance_type": "t2.micro",
"region": "us-east-1",
"ssh_username": "ec2-user",
"ssh_file_transfer_method": "sftp",
"source_ami": "ami-8da458e6",
"tags": {
"packer-test": "true"
}
}],
"provisioners": [{
"type": "file",
"source": "dir/",
"destination": "/tmp"
}, {
"type": "shell",
"inline": ["cat /tmp/file.txt"]
}]
}

View File

@ -0,0 +1,23 @@
{
"builders": [{
"type": "amazon-ebs",
"ami_name": "packer-test {{timestamp}}",
"instance_type": "t2.micro",
"region": "us-east-1",
"ssh_username": "ec2-user",
"ssh_file_transfer_method": "sftp",
"source_ami": "ami-8da458e6",
"tags": {
"packer-test": "true"
}
}],
"provisioners": [{
"type": "file",
"source": "file.txt",
"destination": "/tmp/file.txt"
}, {
"type": "shell",
"inline": ["cat /tmp/file.txt"]
}]
}

View File

@ -32,3 +32,21 @@ teardown() {
[ "$status" -eq 0 ]
[[ "$output" == *"337 miles"* ]]
}
@test "file provisioner: single file through sftp" {
run packer build $FIXTURE_ROOT/file_sftp.json
[ "$status" -eq 0 ]
[[ "$output" == *"24901 miles"* ]]
}
@test "file provisioner: directory through sftp (no trailing slash)" {
run packer build $FIXTURE_ROOT/dir_no_trailing_sftp.json
[ "$status" -eq 0 ]
[[ "$output" == *"337 miles"* ]]
}
@test "file provisioner: directory through sftp (with trailing slash)" {
run packer build $FIXTURE_ROOT/dir_with_trailing_sftp.json
[ "$status" -eq 0 ]
[[ "$output" == *"337 miles"* ]]
}

View File

@ -94,6 +94,9 @@ The SSH communicator has the following options:
* `ssh_bastion_private_key_file` (string) - A private key file to use
to authenticate with the bastion host.
* `ssh_file_transfer_method` (`scp` or `sftp`) - How to transfer files, Secure
copy (default) or SSH File Transfer Protocol.
## WinRM Communicator