communicator/ssh: buffer file on disk to read length [GH-561]

This commit is contained in:
Mitchell Hashimoto 2013-11-02 12:07:45 +04:00
parent 7f639d89b6
commit 254653475e
2 changed files with 34 additions and 12 deletions

View File

@ -13,6 +13,8 @@ BUG FIXES:
* builder/openstack: Properly scrub password from logs [GH-554] * builder/openstack: Properly scrub password from logs [GH-554]
* common/uuid: Use cryptographically secure PRNG when generating * common/uuid: Use cryptographically secure PRNG when generating
UUIDs. [GH-552] UUIDs. [GH-552]
* communicator/ssh: File uploads that exceed the size of memory no longer
cause crashes. [GH-561]
## 0.3.10 (October 20, 2013) ## 0.3.10 (October 20, 2013)

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"os" "os"
@ -362,30 +363,49 @@ func checkSCPStatus(r *bufio.Reader) error {
} }
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error {
// Determine the length of the upload content by copying it // Create a temporary file where we can copy the contents of the src
// into an in-memory buffer. Note that this means what we upload // so that we can determine the length, since SCP is length-prefixed.
// must fit into memory. tf, err := ioutil.TempFile("", "packer-upload")
log.Println("Copying input data into in-memory buffer so we can get the length") if err != nil {
inputBuf := new(bytes.Buffer) return fmt.Errorf("Error creating temporary file for upload: %s", err)
if _, err := io.Copy(inputBuf, src); err != nil { }
defer os.Remove(tf.Name())
defer tf.Close()
log.Println("Copying input data into temporary file so we can read the length")
if _, err := io.Copy(tf, src); err != nil {
return err return err
} }
// Sync the file so that the contents are definitely on disk, then
// read the length of it.
if err := tf.Sync(); err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
// Seek the file to the beginning so we can re-read all of it
if _, err := tf.Seek(0, 0); err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
fi, err := tf.Stat()
if err != nil {
return fmt.Errorf("Error creating temporary file for upload: %s", err)
}
// Start the protocol // Start the protocol
log.Println("Beginning file upload...") log.Println("Beginning file upload...")
fmt.Fprintln(w, "C0644", inputBuf.Len(), dst) fmt.Fprintln(w, "C0644", fi.Size(), dst)
err := checkSCPStatus(r) if err := checkSCPStatus(r); err != nil {
if err != nil {
return err return err
} }
if _, err := io.Copy(w, inputBuf); err != nil { if _, err := io.Copy(w, tf); err != nil {
return err return err
} }
fmt.Fprint(w, "\x00") fmt.Fprint(w, "\x00")
err = checkSCPStatus(r) if err := checkSCPStatus(r); err != nil {
if err != nil {
return err return err
} }