packer-cn/common/download.go

732 lines
16 KiB
Go
Raw Normal View History

2013-06-12 20:41:44 -04:00
package common
import (
"bytes"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
2013-06-12 20:41:44 -04:00
"encoding/hex"
"errors"
"fmt"
"hash"
"log"
2013-06-12 20:41:44 -04:00
"net/url"
"os"
"path"
2017-03-01 15:37:05 -05:00
"runtime"
"strings"
2013-06-12 20:41:44 -04:00
)
// imports related to each Downloader implementation
import (
2017-03-01 15:37:05 -05:00
"github.com/jlaffaye/ftp"
"gopkg.in/cheggaaa/pb.v1"
"io"
"net/http"
2017-03-01 15:37:05 -05:00
"path/filepath"
)
2013-06-12 20:41:44 -04:00
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
type DownloadConfig struct {
// The source URL in the form of a string.
Url string
// This is the path to download the file to.
TargetPath string
// DownloaderMap maps a schema to a Download.
DownloaderMap map[string]Downloader
// If true, this will copy even a local file to the target
// location. If false, then it will "download" the file by just
// returning the local path to the file.
CopyFile bool
// The hashing implementation to use to checksum the downloaded file.
Hash hash.Hash
// The checksum for the downloaded file. The hash implementation configuration
// for the downloader will be used to verify with this checksum after
// it is downloaded.
Checksum []byte
// What to use for the user agent for HTTP requests. If set to "", use the
// default user agent provided by Go.
UserAgent string
2013-06-12 20:41:44 -04:00
}
// DownloadClient performs a configured download. It selects a Downloader for
// the URL's scheme, runs the transfer and verifies the checksum.
type DownloadClient struct {
	// config is the configuration this client was created with.
	config *DownloadConfig
	// downloader is the Downloader Get selected for the URL scheme; it is
	// nil until Get has run.
	downloader Downloader
}
// HashForType maps a checksum type name ("md5", "sha1", "sha256" or
// "sha512", matched case-sensitively) to a freshly constructed hash.Hash.
// Any other name yields nil.
func HashForType(t string) hash.Hash {
	constructors := map[string]func() hash.Hash{
		"md5":    md5.New,
		"sha1":   sha1.New,
		"sha256": sha256.New,
		"sha512": sha512.New,
	}
	newHash, known := constructors[t]
	if !known {
		return nil
	}
	return newHash()
}
2013-06-12 20:41:44 -04:00
// NewDownloadClient returns a new DownloadClient for the given
// configuration.
func NewDownloadClient(c *DownloadConfig, bar *pb.ProgressBar) *DownloadClient {
const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */
// If no custom progress-bar was specified, then create a default one.
if bar == nil {
bar = pb.New64(0)
}
// Create downloader map if it hasn't been specified already.
2013-06-12 20:41:44 -04:00
if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{
"file": &FileDownloader{progress: bar, bufferSize: nil},
"ftp": &FTPDownloader{progress: bar, userInfo: url.UserPassword("anonymous", "anonymous@"), mtu: mtu},
"http": &HTTPDownloader{progress: bar, userAgent: c.UserAgent},
"https": &HTTPDownloader{progress: bar, userAgent: c.UserAgent},
"smb": &SMBDownloader{progress: bar, bufferSize: nil},
2013-06-12 20:41:44 -04:00
}
}
return &DownloadClient{config: c}
}
// Downloader is the common contract of every transfer implementation: a
// transfer can be cancelled or resumed, and reports how many bytes it has
// moved so far out of how many in total.
type Downloader interface {
	// Resume continues a previously cancelled transfer.
	Resume()
	// Cancel stops the transfer in progress.
	Cancel()
	// Progress returns the number of bytes transferred so far.
	Progress() uint64
	// Total returns the expected number of bytes for the transfer.
	Total() uint64
}
// LocalDownloader is implemented by downloaders that can translate a URL
// into a path the platform can open directly, so the data can be used in
// place instead of being copied.
type LocalDownloader interface {
	// toPath resolves the URL to a local path, using the given base
	// directory for relative forms.
	toPath(string, url.URL) (string, error)
}
// RemoteDownloader is implemented by downloaders that can fetch the data
// behind a URL and write it into an open destination file.
type RemoteDownloader interface {
	// Download transfers the resource named by the URL into the file.
	Download(*os.File, *url.URL) error
}
// Cancel is intended to abort a download started by Get. It is not
// implemented yet and currently has no effect.
func (d *DownloadClient) Cancel() {
	// TODO(mitchellh): Implement
}
func (d *DownloadClient) Get() (string, error) {
// If we already have the file and it matches, then just return the target path.
if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
2015-08-19 16:15:23 -04:00
log.Println("[DEBUG] Initial checksum matched, no download needed.")
2013-06-12 20:41:44 -04:00
return d.config.TargetPath, nil
}
/* parse the configuration url into a net/url object */
u, err := url.Parse(d.config.Url)
2017-03-01 15:37:05 -05:00
if err != nil {
return "", err
}
log.Printf("Parsed URL: %#v", u)
/* use the current working directory as the base for relative uri's */
2017-03-01 15:37:05 -05:00
cwd, err := os.Getwd()
if err != nil {
return "", err
}
// Determine which is the correct downloader to use
2013-06-12 20:41:44 -04:00
var finalPath string
2013-08-03 16:34:48 -04:00
var ok bool
d.downloader, ok = d.config.DownloaderMap[u.Scheme]
if !ok {
return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
}
2017-03-01 15:37:05 -05:00
remote, ok := d.downloader.(RemoteDownloader)
if !ok {
return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, d.downloader)
}
2017-03-01 15:37:05 -05:00
local, ok := d.downloader.(LocalDownloader)
if !ok && !d.config.CopyFile {
d.config.CopyFile = true
}
// If we're copying the file, then just use the actual downloader
if d.config.CopyFile {
var f *os.File
finalPath = d.config.TargetPath
2013-06-12 20:41:44 -04:00
f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
2017-03-01 15:37:05 -05:00
if err != nil {
return "", err
}
2013-06-12 20:41:44 -04:00
log.Printf("[DEBUG] Downloading: %s", u.String())
err = remote.Download(f, u)
2015-06-22 15:17:29 -04:00
f.Close()
2017-03-01 15:37:05 -05:00
if err != nil {
return "", err
}
2017-03-01 15:37:05 -05:00
// Otherwise if our Downloader is a LocalDownloader we can just use the
// path after transforming it.
} else {
2017-03-01 15:37:05 -05:00
finalPath, err = local.toPath(cwd, *u)
if err != nil {
return "", err
}
log.Printf("[DEBUG] Using local file: %s", finalPath)
2013-06-12 20:41:44 -04:00
}
if d.config.Hash != nil {
var verify bool
verify, err = d.VerifyChecksum(finalPath)
if err == nil && !verify {
// Only delete the file if we made a copy or downloaded it
2017-03-01 15:37:05 -05:00
if d.config.CopyFile {
os.Remove(finalPath)
}
2015-06-22 15:17:29 -04:00
err = fmt.Errorf(
"checksums didn't match expected: %s",
hex.EncodeToString(d.config.Checksum))
2013-06-12 20:41:44 -04:00
}
}
return finalPath, err
}
// VerifyChecksum tests that the path matches the checksum for the
// download.
func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
if d.config.Checksum == nil || d.config.Hash == nil {
return false, errors.New("Checksum or Hash isn't set on download.")
}
f, err := os.Open(path)
if err != nil {
return false, err
}
defer f.Close()
log.Printf("Verifying checksum of %s", path)
2013-06-12 20:41:44 -04:00
d.config.Hash.Reset()
io.Copy(d.config.Hash, f)
2017-03-28 21:29:55 -04:00
return bytes.Equal(d.config.Hash.Sum(nil), d.config.Checksum), nil
2013-06-12 20:41:44 -04:00
}
// HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP.
type HTTPDownloader struct {
current uint64
total uint64
userAgent string
progress *pb.ProgressBar
2013-06-12 20:41:44 -04:00
}
func (d *HTTPDownloader) Cancel() {
// TODO(mitchellh): Implement
}
func (d *HTTPDownloader) Resume() {
2013-06-12 20:41:44 -04:00
// TODO(mitchellh): Implement
}
func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
log.Printf("Starting download over HTTP: %s", src.String())
2015-06-22 15:14:35 -04:00
// Seek to the beginning by default
if _, err := dst.Seek(0, 0); err != nil {
return err
}
// Reset our progress
d.current = 0
2015-06-22 15:14:35 -04:00
// Make the request. We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't
// support HEAD requests, we just fall back to GET.
req, err := http.NewRequest("HEAD", src.String(), nil)
if err != nil {
return err
}
if d.userAgent != "" {
req.Header.Set("User-Agent", d.userAgent)
}
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
resp, err := httpClient.Do(req)
2015-06-22 15:14:35 -04:00
if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) {
// If the HEAD request succeeded, then attempt to set the range
// query if we can.
if resp.Header.Get("Accept-Ranges") == "bytes" {
if fi, err := dst.Stat(); err == nil {
if _, err = dst.Seek(0, os.SEEK_END); err == nil {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
d.current = uint64(fi.Size())
2015-06-22 15:14:35 -04:00
}
}
}
}
2015-06-22 15:14:35 -04:00
// Set the request to GET now, and redo the query to download
req.Method = "GET"
resp, err = httpClient.Do(req)
if err != nil {
return err
}
d.total = d.current + uint64(resp.ContentLength)
d.progress.Total = int64(d.total)
progressBar := d.progress.Start()
progressBar.Set64(int64(d.current))
2013-06-12 20:41:44 -04:00
var buffer [4096]byte
for {
n, err := resp.Body.Read(buffer[:])
if err != nil && err != io.EOF {
return err
}
d.current += uint64(n)
progressBar.Set64(int64(d.current))
2013-06-12 20:41:44 -04:00
if _, werr := dst.Write(buffer[:n]); werr != nil {
return werr
}
if err == io.EOF {
break
}
}
progressBar.Finish()
2013-06-12 20:41:44 -04:00
return nil
}
func (d *HTTPDownloader) Progress() uint64 {
return d.current
2013-06-12 20:41:44 -04:00
}
func (d *HTTPDownloader) Total() uint64 {
2013-06-12 20:41:44 -04:00
return d.total
}
// FTPDownloader is an implementation of Downloader that downloads
// files over FTP.
type FTPDownloader struct {
userInfo *url.Userinfo
2017-03-01 15:37:05 -05:00
mtu uint
active bool
current uint64
total uint64
progress *pb.ProgressBar
}
func (d *FTPDownloader) Progress() uint64 {
return d.current
}
func (d *FTPDownloader) Total() uint64 {
return d.total
}
func (d *FTPDownloader) Cancel() {
d.active = false
}
func (d *FTPDownloader) Resume() {
// TODO: Implement
}
func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
var userinfo *url.Userinfo
userinfo = d.userInfo
d.active = false
// check the uri is correct
if src == nil || src.Scheme != "ftp" {
return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
}
uri := src
// connect to ftp server
var cli *ftp.ServerConn
log.Printf("Starting download over FTP: %s : %s\n", uri.Host, uri.Path)
2017-03-01 15:37:05 -05:00
cli, err := ftp.Dial(uri.Host)
if err != nil {
return nil
}
defer cli.Quit()
// handle authentication
2017-03-01 15:37:05 -05:00
if uri.User != nil {
userinfo = uri.User
}
2017-03-01 15:37:05 -05:00
pass, ok := userinfo.Password()
if !ok {
pass = "ftp@"
}
log.Printf("Authenticating to FTP server: %s : %s\n", userinfo.Username(), pass)
err = cli.Login(userinfo.Username(), pass)
2017-03-01 15:37:05 -05:00
if err != nil {
return err
}
// locate specified path
p := path.Dir(uri.Path)
log.Printf("Changing to FTP directory : %s\n", p)
err = cli.ChangeDir(p)
2017-03-01 15:37:05 -05:00
if err != nil {
return nil
}
2017-03-01 15:37:05 -05:00
curpath, err := cli.CurrentDir()
if err != nil {
return err
}
log.Printf("Current FTP directory : %s\n", curpath)
// collect stats about the specified file
var name string
var entry *ftp.Entry
2017-03-01 15:37:05 -05:00
_, name = path.Split(uri.Path)
entry = nil
2017-03-01 15:37:05 -05:00
entries, err := cli.List(curpath)
for _, e := range entries {
if e.Type == ftp.EntryTypeFile && e.Name == name {
entry = e
break
}
}
if entry == nil {
return fmt.Errorf("Unable to find file: %s", uri.Path)
}
log.Printf("Found file : %s : %v bytes\n", entry.Name, entry.Size)
d.current = 0
d.total = entry.Size
d.progress.Total = int64(d.total)
progressBar := d.progress.Start()
// download specified file
d.active = true
reader, err := cli.RetrFrom(uri.Path, d.current)
2017-03-01 15:37:05 -05:00
if err != nil {
return nil
}
// do it in a goro so that if someone wants to cancel it, they can
errch := make(chan error)
go func(d *FTPDownloader, r io.Reader, w io.Writer, e chan error) {
2017-03-01 15:37:05 -05:00
for d.active {
n, err := io.CopyN(w, r, int64(d.mtu))
if err != nil {
break
}
d.current += uint64(n)
progressBar.Set64(int64(d.current))
}
d.active = false
e <- err
}(d, reader, dst, errch)
// spin until it's done
err = <-errch
progressBar.Finish()
reader.Close()
if err == nil && d.current != d.total {
err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.current, d.total)
}
// log out
cli.Logout()
return err
}
// FileDownloader is a Downloader for file:// URLs backed by the regular
// filesystem. It can copy the file into a destination (RemoteDownloader) or
// resolve the URL to a path usable in place (LocalDownloader).
type FileDownloader struct {
	// bufferSize selects the copy strategy: nil copies in a single call,
	// otherwise the file is copied *bufferSize bytes at a time.
	bufferSize *uint
	// active is set while a copy runs; Cancel clears it.
	active   bool
	current  uint64
	total    uint64
	progress *pb.ProgressBar
}

// Progress reports how many bytes have been copied so far.
func (d *FileDownloader) Progress() uint64 {
	return d.current
}

// Total reports the size of the source file being copied.
func (d *FileDownloader) Total() uint64 {
	return d.total
}

// Cancel clears the active flag, which the chunked copy loop checks between
// chunks.
func (d *FileDownloader) Cancel() {
	d.active = false
}

// Resume is a placeholder; resuming local copies is not implemented.
func (d *FileDownloader) Resume() {
	// TODO: Implement
}
2017-03-01 15:37:05 -05:00
func (d *FileDownloader) toPath(base string, uri url.URL) (string, error) {
var result string
// absolute path -- file://c:/absolute/path -> c:/absolute/path
if strings.HasSuffix(uri.Host, ":") {
result = path.Join(uri.Host, uri.Path)
2017-03-01 15:37:05 -05:00
// semi-absolute path (current drive letter)
// -- file:///absolute/path -> /absolute/path
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
result = path.Join(filepath.VolumeName(base), uri.Path)
2017-03-01 15:37:05 -05:00
// relative path -- file://./relative/path -> ./relative/path
} else if uri.Host == "." {
result = path.Join(base, uri.Path)
2017-03-01 15:37:05 -05:00
// relative path -- file://relative/path -> ./relative/path
} else {
result = path.Join(base, uri.Host, uri.Path)
}
2017-03-01 15:37:05 -05:00
return filepath.ToSlash(result), nil
}
func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
d.active = false
/* check the uri's scheme to make sure it matches */
if src == nil || src.Scheme != "file" {
return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
}
uri := src
/* use the current working directory as the base for relative uri's */
2017-03-01 15:37:05 -05:00
cwd, err := os.Getwd()
if err != nil {
return err
}
/* determine which uri format is being used and convert to a real path */
2017-03-01 15:37:05 -05:00
realpath, err := d.toPath(cwd, *uri)
if err != nil {
return err
}
/* download the file using the operating system's facilities */
d.current = 0
d.active = true
f, err := os.Open(realpath)
2017-03-01 15:37:05 -05:00
if err != nil {
return err
}
defer f.Close()
// get the file size
fi, err := f.Stat()
2017-03-01 15:37:05 -05:00
if err != nil {
return err
}
d.total = uint64(fi.Size())
d.progress.Total = int64(d.total)
progressBar := d.progress.Start()
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
2017-03-01 15:37:05 -05:00
n, err = io.Copy(dst, f)
d.active = false
d.current += uint64(n)
progressBar.Set64(int64(d.current))
2017-03-01 15:37:05 -05:00
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
2017-03-01 15:37:05 -05:00
go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.current += uint64(n)
progressBar.Set64(int64(d.current))
}
d.active = false
e <- err
}(d, f, dst, errch)
// ...and we spin until it's done
err = <-errch
}
progressBar.Finish()
f.Close()
return err
}
// SMBDownloader is a Downloader for smb:// URLs. It opens them as Windows UNC
// ("\\host\share") paths and is only usable when running on Windows.
type SMBDownloader struct {
	// bufferSize selects the copy strategy: nil copies in a single call,
	// otherwise the file is copied *bufferSize bytes at a time.
	bufferSize *uint
	// active is set while a copy runs; Cancel clears it.
	active   bool
	current  uint64
	total    uint64
	progress *pb.ProgressBar
}

// Progress reports how many bytes have been copied so far.
func (d *SMBDownloader) Progress() uint64 {
	return d.current
}

// Total reports the size of the remote file being copied.
func (d *SMBDownloader) Total() uint64 {
	return d.total
}

// Cancel clears the active flag, which the chunked copy loop checks between
// chunks.
func (d *SMBDownloader) Cancel() {
	d.active = false
}

// Resume is a placeholder; resuming SMB copies is not implemented.
func (d *SMBDownloader) Resume() {
	// TODO: Implement
}
2017-03-01 15:37:05 -05:00
func (d *SMBDownloader) toPath(base string, uri url.URL) (string, error) {
const UNCPrefix = string(os.PathSeparator) + string(os.PathSeparator)
if runtime.GOOS != "windows" {
2017-03-01 15:37:05 -05:00
return "", fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
}
return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil
}
func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
/* first we warn the world if we're not running windows */
if runtime.GOOS != "windows" {
return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
}
d.active = false
/* convert the uri using the net/url module to a UNC path */
if src == nil || src.Scheme != "smb" {
return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
}
uri := src
/* use the current working directory as the base for relative uri's */
2017-03-01 15:37:05 -05:00
cwd, err := os.Getwd()
if err != nil {
return err
}
/* convert uri to an smb-path */
2017-03-01 15:37:05 -05:00
realpath, err := d.toPath(cwd, *uri)
if err != nil {
return err
}
/* Open up the "\\"-prefixed path using the Windows filesystem */
d.current = 0
d.active = true
f, err := os.Open(realpath)
2017-03-01 15:37:05 -05:00
if err != nil {
return err
}
defer f.Close()
// get the file size (at the risk of performance)
fi, err := f.Stat()
2017-03-01 15:37:05 -05:00
if err != nil {
return err
}
d.total = uint64(fi.Size())
d.progress.Total = int64(d.total)
progressBar := d.progress.Start()
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
2017-03-01 15:37:05 -05:00
n, err = io.Copy(dst, f)
d.active = false
d.current += uint64(n)
progressBar.Set64(int64(d.current))
2017-03-01 15:37:05 -05:00
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
2017-03-01 15:37:05 -05:00
go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.current += uint64(n)
progressBar.Set64(int64(d.current))
}
d.active = false
e <- err
}(d, f, dst, errch)
// ...and as usual we spin until it's done
err = <-errch
}
progressBar.Finish()
f.Close()
return err
}