Added the file, ftp, and smb downloaders to common/download.go
This commit is contained in:
parent
da9c94b345
commit
60831801a7
|
@ -12,7 +12,6 @@ import (
|
|||
"hash"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
|
@ -21,6 +20,13 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"github.com/jlaffeye/ftp"
|
||||
"bufio"
|
||||
)
|
||||
|
||||
|
||||
// DownloadConfig is the configuration given to instantiate a new
|
||||
// download instance. Once a configuration is used to instantiate
|
||||
// a download client, it must not be modified.
|
||||
|
@ -78,10 +84,15 @@ func HashForType(t string) hash.Hash {
|
|||
// NewDownloadClient returns a new DownloadClient for the given
|
||||
// configuration.
|
||||
func NewDownloadClient(c *DownloadConfig) *DownloadClient {
|
||||
const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */
|
||||
|
||||
if c.DownloaderMap == nil {
|
||||
c.DownloaderMap = map[string]Downloader{
|
||||
"file": &FileDownloader{bufferSize: nil},
|
||||
"ftp": &FTPDownloader{userInfo: url.Userinfo{username:"anonymous", password: "anonymous@"}, mtu: mtu},
|
||||
"http": &HTTPDownloader{userAgent: c.UserAgent},
|
||||
"https": &HTTPDownloader{userAgent: c.UserAgent},
|
||||
"smb": &SMBDownloader{bufferSize: nil}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,44 +112,6 @@ func (d *DownloadClient) Cancel() {
|
|||
// TODO(mitchellh): Implement
|
||||
}
|
||||
|
||||
// Take a uri and convert it to a path that makes sense on the Windows platform
|
||||
func NormalizeWindowsURL(basepath string, url url.URL) string {
|
||||
// This logic must correspond to the same logic in the NormalizeWindowsURL
|
||||
// function found in common/config.go since that function _also_ checks that
|
||||
// the url actually exists in file form.
|
||||
|
||||
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
|
||||
|
||||
// move any extra path components that were parsed into Host due
|
||||
// to UNC into the url.Path field so that it's PathSeparators get
|
||||
// normalized
|
||||
if len(url.Host) >= len(UNCPrefix) && url.Host[:len(UNCPrefix)] == UNCPrefix {
|
||||
idx := strings.Index(url.Host[len(UNCPrefix):], string(os.PathSeparator))
|
||||
if idx > -1 {
|
||||
url.Path = filepath.ToSlash(url.Host[idx+len(UNCPrefix):]) + url.Path
|
||||
url.Host = url.Host[:idx+len(UNCPrefix)]
|
||||
}
|
||||
}
|
||||
|
||||
// clean up backward-slashes since they only matter when part of a unc path
|
||||
urlPath := filepath.ToSlash(url.Path)
|
||||
|
||||
// semi-absolute path (current drive letter) -- file:///absolute/path
|
||||
if url.Host == "" && len(urlPath) > 0 && urlPath[0] == '/' {
|
||||
return path.Join(filepath.VolumeName(basepath), urlPath)
|
||||
|
||||
// relative path -- file://./relative/path
|
||||
// file://relative/path
|
||||
} else if url.Host == "" || (len(url.Host) > 0 && url.Host[0] == '.') {
|
||||
return path.Join(filepath.ToSlash(basepath), urlPath)
|
||||
}
|
||||
|
||||
// absolute path
|
||||
// UNC -- file://\\host/share/whatever
|
||||
// drive -- file://c:/absolute/path
|
||||
return path.Join(url.Host, urlPath)
|
||||
}
|
||||
|
||||
func (d *DownloadClient) Get() (string, error) {
|
||||
// If we already have the file and it matches, then just return the target path.
|
||||
if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
|
||||
|
@ -153,6 +126,11 @@ func (d *DownloadClient) Get() (string, error) {
|
|||
|
||||
log.Printf("Parsed URL: %#v", u)
|
||||
|
||||
/* FIXME:
|
||||
handle the special case of d.config.CopyFile which returns the path
|
||||
in an os-specific format.
|
||||
*/
|
||||
|
||||
// Files when we don't copy the file are special cased.
|
||||
var f *os.File
|
||||
var finalPath string
|
||||
|
@ -271,7 +249,7 @@ func (*HTTPDownloader) Cancel() {
|
|||
}
|
||||
|
||||
func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
|
||||
log.Printf("Starting download: %s", src.String())
|
||||
log.Printf("Starting download over HTTP: %s", src.String())
|
||||
|
||||
// Seek to the beginning by default
|
||||
if _, err := dst.Seek(0, 0); err != nil {
|
||||
|
@ -350,3 +328,298 @@ func (d *HTTPDownloader) Progress() uint {
|
|||
func (d *HTTPDownloader) Total() uint {
|
||||
return d.total
|
||||
}
|
||||
|
||||
// FTPDownloader is an implementation of Downloader that downloads
|
||||
// files over FTP.
|
||||
type FTPDownloader struct {
|
||||
userInfo url.UserInfo
|
||||
mtu uint
|
||||
|
||||
active bool
|
||||
progress uint
|
||||
total uint
|
||||
}
|
||||
|
||||
func (*FTPDownloader) Cancel() {
|
||||
d.active = false
|
||||
}
|
||||
|
||||
func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
|
||||
var userinfo *url.Userinfo
|
||||
|
||||
userinfo = d.userInfo
|
||||
d.active = false
|
||||
|
||||
// check the uri is correct
|
||||
uri, err := url.Parse(src)
|
||||
if err != nil { return err }
|
||||
|
||||
if uri.Scheme != "ftp" {
|
||||
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
|
||||
}
|
||||
|
||||
// connect to ftp server
|
||||
var cli *ftp.ServerConn
|
||||
|
||||
log.Printf("Starting download over FTP: %s : %s\n", uri.Host, Uri.Path)
|
||||
cli,err := ftp.Dial(uri.Host)
|
||||
if err != nil { return nil }
|
||||
defer cli.Close()
|
||||
|
||||
// handle authentication
|
||||
if uri.User != nil { userinfo = uri.User }
|
||||
|
||||
log.Printf("Authenticating to FTP server: %s : %s\n", uri.User.username, uri.User.password)
|
||||
err = cli.Login(userinfo.username, userinfo.password)
|
||||
if err != nil { return err }
|
||||
|
||||
// locate specified path
|
||||
path := path.Dir(uri.Path)
|
||||
|
||||
log.Printf("Changing to FTP directory : %s\n", path)
|
||||
err = cli.ChangeDir(path)
|
||||
if err != nil { return nil }
|
||||
|
||||
curpath,err := cli.CurrentDir()
|
||||
if err != nil { return err }
|
||||
log.Printf("Current FTP directory : %s\n", curpath)
|
||||
|
||||
// collect stats about the specified file
|
||||
var name string
|
||||
var entry *ftp.Entry
|
||||
|
||||
_,name = path.Split(uri.Path)
|
||||
entry = nil
|
||||
|
||||
entries,err := cli.List(curpath)
|
||||
for _,e := range entries {
|
||||
if e.Type == ftp.EntryTypeFile && e.Name == name {
|
||||
entry = e
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if entry == nil {
|
||||
return fmt.Errorf("Unable to find file: %s", uri.Path)
|
||||
}
|
||||
log.Printf("Found file : %s : %v bytes\n", entry.Name, entry.Size)
|
||||
|
||||
d.progress = 0
|
||||
d.total = entry.Size
|
||||
|
||||
// download specified file
|
||||
d.active = true
|
||||
reader,err := cli.RetrFrom(uri.Path, d.progress)
|
||||
if err != nil { return nil }
|
||||
|
||||
// do it in a goro so that if someone wants to cancel it, they can
|
||||
errch := make(chan error)
|
||||
go func(d *FTPDownloader, r *io.Reader, w *bufio.Writer, e chan error) {
|
||||
defer w.Flush()
|
||||
for ; d.active {
|
||||
n,err := io.CopyN(writer, reader, d.mtu)
|
||||
if err != nil { break }
|
||||
d.progress += n
|
||||
}
|
||||
d.active = false
|
||||
e <- err
|
||||
}(d, reader, bufio.NewWriter(dst), errch)
|
||||
|
||||
// spin until it's done
|
||||
err = <-errch
|
||||
reader.Close()
|
||||
|
||||
if err == nil && d.progress != d.total {
|
||||
err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total)
|
||||
}
|
||||
|
||||
// log out and quit
|
||||
cli.Logout()
|
||||
cli.Quit()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *FTPDownloader) Progress() uint {
|
||||
return d.progress
|
||||
}
|
||||
|
||||
func (d *FTPDownloader) Total() uint {
|
||||
return d.total
|
||||
}
|
||||
|
||||
// FileDownloader is an implementation of Downloader that downloads
|
||||
// files using the regular filesystem.
|
||||
type FileDownloader struct {
|
||||
bufferSize *uint
|
||||
|
||||
active bool
|
||||
progress uint
|
||||
total uint
|
||||
}
|
||||
|
||||
func (*FileDownloader) Cancel() {
|
||||
d.active = false
|
||||
}
|
||||
|
||||
func (d *FileDownloader) Progress() uint {
|
||||
return d.progress
|
||||
}
|
||||
|
||||
func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
|
||||
d.active = false
|
||||
|
||||
/* parse the uri using the net/url module */
|
||||
uri, err := url.Parse(src)
|
||||
if uri.Scheme != "file" {
|
||||
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
|
||||
}
|
||||
|
||||
/* use the current working directory as the base for relative uri's */
|
||||
cwd,err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Unable to get working directory")
|
||||
}
|
||||
|
||||
/* determine which uri format is being used and convert to a real path */
|
||||
var realpath string, basepath string
|
||||
basepath = filepath.ToSlash(cwd)
|
||||
|
||||
// absolute path -- file://c:/absolute/path
|
||||
if strings.HasSuffix(uri.Host, ":") {
|
||||
realpath = path.Join(uri.Host, uri.Path)
|
||||
|
||||
// semi-absolute path (current drive letter) -- file:///absolute/path
|
||||
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
|
||||
realpath = path.Join(filepath.VolumeName(basepath), uri.Path)
|
||||
|
||||
// relative path -- file://./relative/path
|
||||
} else if uri.Host == "." {
|
||||
realpath = path.Join(basepath, uri.Path)
|
||||
|
||||
// relative path -- file://relative/path
|
||||
} else {
|
||||
realpath = path.Join(basepath, uri.Host, uri.Path)
|
||||
}
|
||||
|
||||
/* download the file using the operating system's facilities */
|
||||
d.progress = 0
|
||||
d.active = true
|
||||
|
||||
f, err = os.Open(realpath)
|
||||
if err != nil { return err }
|
||||
defer f.Close()
|
||||
|
||||
// get the file size
|
||||
fi, err := f.Stat()
|
||||
if err != nil { return err }
|
||||
d.total = fi.Size()
|
||||
|
||||
// no bufferSize specified, so copy synchronously.
|
||||
if d.bufferSize == nil {
|
||||
n,err := io.Copy(dst, f)
|
||||
d.active = false
|
||||
d.progress += n
|
||||
|
||||
// use a goro in case someone else wants to enable cancel/resume
|
||||
} else {
|
||||
errch := make(chan error)
|
||||
go func(d* FileDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) {
|
||||
defer w.Flush()
|
||||
for ; d.active {
|
||||
n,err := io.CopyN(writer, reader, d.bufferSize)
|
||||
if err != nil { break }
|
||||
d.progress += n
|
||||
}
|
||||
d.active = false
|
||||
e <- err
|
||||
}(d, f, bufio.NewWriter(dst), errch)
|
||||
|
||||
// ...and we spin until it's done
|
||||
err = <-errch
|
||||
}
|
||||
f.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *FileDownloader) Total() uint {
|
||||
return d.total
|
||||
}
|
||||
|
||||
// SMBDownloader is an implementation of Downloader that downloads
|
||||
// files using the "\\" path format on Windows
|
||||
type SMBDownloader struct {
|
||||
bufferSize *uint
|
||||
|
||||
active bool
|
||||
progress uint
|
||||
total uint
|
||||
}
|
||||
|
||||
func (*SMBDownloader) Cancel() {
|
||||
d.active = false
|
||||
}
|
||||
|
||||
func (d *SMBDownloader) Progress() uint {
|
||||
return d.progress
|
||||
}
|
||||
|
||||
func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
|
||||
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
|
||||
d.active = false
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
/* convert the uri using the net/url module to a UNC path */
|
||||
var realpath string
|
||||
uri, err := url.Parse(src)
|
||||
if uri.Scheme != "smb" {
|
||||
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
|
||||
}
|
||||
|
||||
realpath = UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path))
|
||||
|
||||
/* Open up the "\\"-prefixed path using the Windows filesystem */
|
||||
d.progress = 0
|
||||
d.active = true
|
||||
|
||||
f, err = os.Open(realpath)
|
||||
if err != nil { return err }
|
||||
defer f.Close()
|
||||
|
||||
// get the file size (at the risk of performance)
|
||||
fi, err := f.Stat()
|
||||
if err != nil { return err }
|
||||
d.total = fi.Size()
|
||||
|
||||
// no bufferSize specified, so copy synchronously.
|
||||
if d.bufferSize == nil {
|
||||
n,err := io.Copy(dst, f)
|
||||
d.active = false
|
||||
d.progress += n
|
||||
|
||||
// use a goro in case someone else wants to enable cancel/resume
|
||||
} else {
|
||||
errch := make(chan error)
|
||||
go func(d* SMBDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) {
|
||||
defer w.Flush()
|
||||
for ; d.active {
|
||||
n,err := io.CopyN(writer, reader, d.bufferSize)
|
||||
if err != nil { break }
|
||||
d.progress += n
|
||||
}
|
||||
d.active = false
|
||||
e <- err
|
||||
}(d, f, bufio.NewWriter(dst), errch)
|
||||
|
||||
// ...and as usual we spin until it's done
|
||||
err = <-errch
|
||||
}
|
||||
f.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *SMBDownloader) Total() uint {
|
||||
return d.total
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue