Added the file, ftp, and smb downloaders to common/download.go

This commit is contained in:
Ali Rizvi-Santiago 2016-04-05 13:11:30 -05:00
parent da9c94b345
commit 60831801a7
1 changed files with 313 additions and 40 deletions

View File

@ -12,7 +12,6 @@ import (
"hash"
"io"
"log"
"net/http"
"net/url"
"os"
"runtime"
@ -21,6 +20,13 @@ import (
"strings"
)
import (
"net/http"
"github.com/jlaffeye/ftp"
"bufio"
)
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
@ -78,10 +84,15 @@ func HashForType(t string) hash.Hash {
// NewDownloadClient returns a new DownloadClient for the given
// configuration.
func NewDownloadClient(c *DownloadConfig) *DownloadClient {
const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */
if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{
"file": &FileDownloader{bufferSize: nil},
"ftp": &FTPDownloader{userInfo: url.Userinfo{username:"anonymous", password: "anonymous@"}, mtu: mtu},
"http": &HTTPDownloader{userAgent: c.UserAgent},
"https": &HTTPDownloader{userAgent: c.UserAgent},
"smb": &SMBDownloader{bufferSize: nil}
}
}
@ -101,44 +112,6 @@ func (d *DownloadClient) Cancel() {
// TODO(mitchellh): Implement
}
// Take a uri and convert it to a path that makes sense on the Windows platform
func NormalizeWindowsURL(basepath string, url url.URL) string {
// This logic must correspond to the same logic in the NormalizeWindowsURL
// function found in common/config.go since that function _also_ checks that
// the url actually exists in file form.
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
// move any extra path components that were parsed into Host due
// to UNC into the url.Path field so that it's PathSeparators get
// normalized
if len(url.Host) >= len(UNCPrefix) && url.Host[:len(UNCPrefix)] == UNCPrefix {
idx := strings.Index(url.Host[len(UNCPrefix):], string(os.PathSeparator))
if idx > -1 {
url.Path = filepath.ToSlash(url.Host[idx+len(UNCPrefix):]) + url.Path
url.Host = url.Host[:idx+len(UNCPrefix)]
}
}
// clean up backward-slashes since they only matter when part of a unc path
urlPath := filepath.ToSlash(url.Path)
// semi-absolute path (current drive letter) -- file:///absolute/path
if url.Host == "" && len(urlPath) > 0 && urlPath[0] == '/' {
return path.Join(filepath.VolumeName(basepath), urlPath)
// relative path -- file://./relative/path
// file://relative/path
} else if url.Host == "" || (len(url.Host) > 0 && url.Host[0] == '.') {
return path.Join(filepath.ToSlash(basepath), urlPath)
}
// absolute path
// UNC -- file://\\host/share/whatever
// drive -- file://c:/absolute/path
return path.Join(url.Host, urlPath)
}
func (d *DownloadClient) Get() (string, error) {
// If we already have the file and it matches, then just return the target path.
if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
@ -153,6 +126,11 @@ func (d *DownloadClient) Get() (string, error) {
log.Printf("Parsed URL: %#v", u)
/* FIXME:
handle the special case of d.config.CopyFile which returns the path
in an os-specific format.
*/
// Files when we don't copy the file are special cased.
var f *os.File
var finalPath string
@ -271,7 +249,7 @@ func (*HTTPDownloader) Cancel() {
}
func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
log.Printf("Starting download: %s", src.String())
log.Printf("Starting download over HTTP: %s", src.String())
// Seek to the beginning by default
if _, err := dst.Seek(0, 0); err != nil {
@ -350,3 +328,298 @@ func (d *HTTPDownloader) Progress() uint {
func (d *HTTPDownloader) Total() uint {
return d.total
}
// FTPDownloader is an implementation of Downloader that downloads
// files over FTP.
type FTPDownloader struct {
userInfo url.UserInfo
mtu uint
active bool
progress uint
total uint
}
func (*FTPDownloader) Cancel() {
d.active = false
}
func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
var userinfo *url.Userinfo
userinfo = d.userInfo
d.active = false
// check the uri is correct
uri, err := url.Parse(src)
if err != nil { return err }
if uri.Scheme != "ftp" {
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
}
// connect to ftp server
var cli *ftp.ServerConn
log.Printf("Starting download over FTP: %s : %s\n", uri.Host, Uri.Path)
cli,err := ftp.Dial(uri.Host)
if err != nil { return nil }
defer cli.Close()
// handle authentication
if uri.User != nil { userinfo = uri.User }
log.Printf("Authenticating to FTP server: %s : %s\n", uri.User.username, uri.User.password)
err = cli.Login(userinfo.username, userinfo.password)
if err != nil { return err }
// locate specified path
path := path.Dir(uri.Path)
log.Printf("Changing to FTP directory : %s\n", path)
err = cli.ChangeDir(path)
if err != nil { return nil }
curpath,err := cli.CurrentDir()
if err != nil { return err }
log.Printf("Current FTP directory : %s\n", curpath)
// collect stats about the specified file
var name string
var entry *ftp.Entry
_,name = path.Split(uri.Path)
entry = nil
entries,err := cli.List(curpath)
for _,e := range entries {
if e.Type == ftp.EntryTypeFile && e.Name == name {
entry = e
break
}
}
if entry == nil {
return fmt.Errorf("Unable to find file: %s", uri.Path)
}
log.Printf("Found file : %s : %v bytes\n", entry.Name, entry.Size)
d.progress = 0
d.total = entry.Size
// download specified file
d.active = true
reader,err := cli.RetrFrom(uri.Path, d.progress)
if err != nil { return nil }
// do it in a goro so that if someone wants to cancel it, they can
errch := make(chan error)
go func(d *FTPDownloader, r *io.Reader, w *bufio.Writer, e chan error) {
defer w.Flush()
for ; d.active {
n,err := io.CopyN(writer, reader, d.mtu)
if err != nil { break }
d.progress += n
}
d.active = false
e <- err
}(d, reader, bufio.NewWriter(dst), errch)
// spin until it's done
err = <-errch
reader.Close()
if err == nil && d.progress != d.total {
err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total)
}
// log out and quit
cli.Logout()
cli.Quit()
return err
}
func (d *FTPDownloader) Progress() uint {
return d.progress
}
func (d *FTPDownloader) Total() uint {
return d.total
}
// FileDownloader is an implementation of Downloader that downloads
// files using the regular filesystem.
type FileDownloader struct {
bufferSize *uint
active bool
progress uint
total uint
}
func (*FileDownloader) Cancel() {
d.active = false
}
func (d *FileDownloader) Progress() uint {
return d.progress
}
func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
d.active = false
/* parse the uri using the net/url module */
uri, err := url.Parse(src)
if uri.Scheme != "file" {
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
}
/* use the current working directory as the base for relative uri's */
cwd,err := os.Getwd()
if err != nil {
return "", fmt.Errorf("Unable to get working directory")
}
/* determine which uri format is being used and convert to a real path */
var realpath string, basepath string
basepath = filepath.ToSlash(cwd)
// absolute path -- file://c:/absolute/path
if strings.HasSuffix(uri.Host, ":") {
realpath = path.Join(uri.Host, uri.Path)
// semi-absolute path (current drive letter) -- file:///absolute/path
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
realpath = path.Join(filepath.VolumeName(basepath), uri.Path)
// relative path -- file://./relative/path
} else if uri.Host == "." {
realpath = path.Join(basepath, uri.Path)
// relative path -- file://relative/path
} else {
realpath = path.Join(basepath, uri.Host, uri.Path)
}
/* download the file using the operating system's facilities */
d.progress = 0
d.active = true
f, err = os.Open(realpath)
if err != nil { return err }
defer f.Close()
// get the file size
fi, err := f.Stat()
if err != nil { return err }
d.total = fi.Size()
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
n,err := io.Copy(dst, f)
d.active = false
d.progress += n
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d* FileDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) {
defer w.Flush()
for ; d.active {
n,err := io.CopyN(writer, reader, d.bufferSize)
if err != nil { break }
d.progress += n
}
d.active = false
e <- err
}(d, f, bufio.NewWriter(dst), errch)
// ...and we spin until it's done
err = <-errch
}
f.Close()
return err
}
func (d *FileDownloader) Total() uint {
return d.total
}
// SMBDownloader is an implementation of Downloader that downloads
// files using the "\\" path format on Windows
type SMBDownloader struct {
bufferSize *uint
active bool
progress uint
total uint
}
func (*SMBDownloader) Cancel() {
d.active = false
}
func (d *SMBDownloader) Progress() uint {
return d.progress
}
func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
d.active = false
if runtime.GOOS != "windows" {
return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
}
/* convert the uri using the net/url module to a UNC path */
var realpath string
uri, err := url.Parse(src)
if uri.Scheme != "smb" {
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
}
realpath = UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path))
/* Open up the "\\"-prefixed path using the Windows filesystem */
d.progress = 0
d.active = true
f, err = os.Open(realpath)
if err != nil { return err }
defer f.Close()
// get the file size (at the risk of performance)
fi, err := f.Stat()
if err != nil { return err }
d.total = fi.Size()
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
n,err := io.Copy(dst, f)
d.active = false
d.progress += n
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d* SMBDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) {
defer w.Flush()
for ; d.active {
n,err := io.CopyN(writer, reader, d.bufferSize)
if err != nil { break }
d.progress += n
}
d.active = false
e <- err
}(d, f, bufio.NewWriter(dst), errch)
// ...and as usual we spin until it's done
err = <-errch
}
f.Close()
return err
}
func (d *SMBDownloader) Total() uint {
return d.total
}