diff --git a/common/download.go b/common/download.go index b5798b76c..d6711b6b4 100644 --- a/common/download.go +++ b/common/download.go @@ -89,7 +89,7 @@ func NewDownloadClient(c *DownloadConfig) *DownloadClient { // downloading it. type Downloader interface { Cancel() - Download(io.Writer, *url.URL) error + Download(*os.File, *url.URL) error Progress() uint Total() uint } @@ -99,6 +99,8 @@ func (d *DownloadClient) Cancel() { } func (d *DownloadClient) Get() (string, error) { + var f *os.File + // If we already have the file and it matches, then just return the target path. if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { log.Println("Initial checksum matched, no download needed.") @@ -131,7 +133,7 @@ func (d *DownloadClient) Get() (string, error) { } // Otherwise, download using the downloader. - f, err := os.Create(finalPath) + f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) if err != nil { return "", err } @@ -195,9 +197,9 @@ func (*HTTPDownloader) Cancel() { // TODO(mitchellh): Implement } -func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error { +func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { log.Printf("Starting download: %s", src.String()) - req, err := http.NewRequest("GET", src.String(), nil) + req, err := http.NewRequest("HEAD", src.String(), nil) if err != nil { return err } @@ -217,10 +219,15 @@ func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error { return err } + req.Method = "GET" if resp.StatusCode != 200 { log.Printf( "Non-200 status code: %d. Getting error body.", resp.StatusCode) + resp, err := httpClient.Do(req) + if err != nil { + return err + } errorBody := new(bytes.Buffer) io.Copy(errorBody, resp.Body) return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s", @@ -228,6 +235,21 @@ func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error { } d.progress = 0 + + if resp.Header.Get("Accept-Ranges") == "bytes" { + if fi, err := dst.Stat(); err == nil { + if _, err = dst.Seek(0, os.SEEK_END); err == nil { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) + d.progress = uint(fi.Size()) + } + } + } + + resp, err = httpClient.Do(req) + if err != nil { + return err + } + d.total = uint(resp.ContentLength) var buffer [4096]byte