diff --git a/communicator/winrm/communicator.go b/communicator/winrm/communicator.go index 69f89194e..355548715 100644 --- a/communicator/winrm/communicator.go +++ b/communicator/winrm/communicator.go @@ -1,8 +1,10 @@ package winrm import ( + "encoding/base64" "fmt" "io" + "io/ioutil" "log" "os" "path/filepath" @@ -143,7 +145,20 @@ func (c *Communicator) UploadDir(dst string, src string, exclude []string) error } func (c *Communicator) Download(src string, dst io.Writer) error { - return fmt.Errorf("WinRM doesn't support download.") + endpoint := winrm.NewEndpoint(c.endpoint.Host, c.endpoint.Port, c.config.Https, c.config.Insecure, nil, nil, nil, c.config.Timeout) + client, err := winrm.NewClient(endpoint, c.config.Username, c.config.Password) + if err != nil { + return err + } + + encodeScript := `$file=[System.IO.File]::ReadAllBytes("%s"); Write-Output $([System.Convert]::ToBase64String($file))` + + base64DecodePipe := &Base64Pipe{w: dst} + + cmd := winrm.Powershell(fmt.Sprintf(encodeScript, src)) + _, err = client.Run(cmd, base64DecodePipe, ioutil.Discard) + + return err } func (c *Communicator) DownloadDir(src string, dst string, exclude []string) error { @@ -164,3 +179,33 @@ func (c *Communicator) newCopyClient() (*winrmcp.Winrmcp, error) { TransportDecorator: c.config.TransportDecorator, }) } + +type Base64Pipe struct { + w io.Writer // underlying writer (file, buffer) +} + +func (d *Base64Pipe) ReadFrom(r io.Reader) (int64, error) { + b, err := ioutil.ReadAll(r) + if err != nil { + return 0, err + } + + var i int + i, err = d.Write(b) + + if err != nil { + return 0, err + } + + return int64(i), err +} + +func (d *Base64Pipe) Write(p []byte) (int, error) { + dst := make([]byte, base64.StdEncoding.DecodedLen(len(p))) + + if _, err := base64.StdEncoding.Decode(dst, p); err != nil { + return 0, err + } + + return d.w.Write(dst) +} diff --git a/communicator/winrm/communicator_test.go b/communicator/winrm/communicator_test.go index 5f61680f9..b80f8316b 100644 --- a/communicator/winrm/communicator_test.go +++ b/communicator/winrm/communicator_test.go @@ -3,6 +3,7 @@ package winrm import ( "bytes" "io" + "strings" "testing" "time" @@ -10,6 +11,9 @@ import ( "github.com/hashicorp/packer/packer" ) +const PAYLOAD = "stuff" +const BASE64_ENCODED_PAYLOAD = "c3R1ZmY=" + func newMockWinRMServer(t *testing.T) *winrmtest.Remote { wrm := winrmtest.NewRemote() @@ -26,9 +30,16 @@ func newMockWinRMServer(t *testing.T) *winrmtest.Remote { return 0 }) + wrm.CommandFunc( + winrmtest.MatchPattern(`^echo `+BASE64_ENCODED_PAYLOAD+` >> ".*"$`), + func(out, err io.Writer) int { + return 0 + }) + wrm.CommandFunc( winrmtest.MatchPattern(`^powershell.exe -EncodedCommand .*$`), func(out, err io.Writer) int { + out.Write([]byte(BASE64_ENCODED_PAYLOAD)) return 0 }) @@ -86,9 +97,21 @@ func TestUpload(t *testing.T) { if err != nil { t.Fatalf("error creating communicator: %s", err) } - - err = c.Upload("C:/Temp/packer.cmd", bytes.NewReader([]byte("something")), nil) + file := "C:/Temp/packer.cmd" + err = c.Upload(file, strings.NewReader(PAYLOAD), nil) if err != nil { t.Fatalf("error uploading file: %s", err) } + + dest := new(bytes.Buffer) + err = c.Download(file, dest) + if err != nil { + t.Fatalf("error downloading file: %s", err) + } + downloadedPayload := strings.TrimRight(dest.String(), "\x00") + + if downloadedPayload != PAYLOAD { + t.Fatalf("files are not equal: expected [%s] length: %v, got [%s] length %v", PAYLOAD, len(PAYLOAD), downloadedPayload, len(downloadedPayload)) + } + } diff --git a/vendor/github.com/masterzen/winrm/client.go b/vendor/github.com/masterzen/winrm/client.go index 03755ff4e..732dd61cb 100644 --- a/vendor/github.com/masterzen/winrm/client.go +++ b/vendor/github.com/masterzen/winrm/client.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "fmt" "io" + "sync" "github.com/masterzen/winrm/soap" ) @@ -114,10 +115,21 @@ func (c *Client) Run(command string, stdout io.Writer, stderr io.Writer) (int, e return 1, err } - go io.Copy(stdout, cmd.Stdout) - go io.Copy(stderr, cmd.Stderr) + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + io.Copy(stdout, cmd.Stdout) + }() + + go func() { + defer wg.Done() + io.Copy(stderr, cmd.Stderr) + }() cmd.Wait() + wg.Wait() return cmd.ExitCode(), cmd.err }