diff --git a/common/download.go b/common/download.go index 65d72d22a..b0550927f 100644 --- a/common/download.go +++ b/common/download.go @@ -12,7 +12,6 @@ import ( "hash" "io" "log" - "net/http" "net/url" "os" "runtime" @@ -21,6 +20,13 @@ import ( "strings" ) +import ( + "net/http" + "github.com/jlaffeye/ftp" + "bufio" +) + + // DownloadConfig is the configuration given to instantiate a new // download instance. Once a configuration is used to instantiate // a download client, it must not be modified. @@ -78,10 +84,15 @@ func HashForType(t string) hash.Hash { // NewDownloadClient returns a new DownloadClient for the given // configuration. func NewDownloadClient(c *DownloadConfig) *DownloadClient { + const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */ + if c.DownloaderMap == nil { c.DownloaderMap = map[string]Downloader{ + "file": &FileDownloader{bufferSize: nil}, + "ftp": &FTPDownloader{userInfo: url.Userinfo{username:"anonymous", password: "anonymous@"}, mtu: mtu}, "http": &HTTPDownloader{userAgent: c.UserAgent}, "https": &HTTPDownloader{userAgent: c.UserAgent}, + "smb": &SMBDownloader{bufferSize: nil} } } @@ -101,44 +112,6 @@ func (d *DownloadClient) Cancel() { // TODO(mitchellh): Implement } -// Take a uri and convert it to a path that makes sense on the Windows platform -func NormalizeWindowsURL(basepath string, url url.URL) string { - // This logic must correspond to the same logic in the NormalizeWindowsURL - // function found in common/config.go since that function _also_ checks that - // the url actually exists in file form. - - const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator) - - // move any extra path components that were parsed into Host due - // to UNC into the url.Path field so that it's PathSeparators get - // normalized - if len(url.Host) >= len(UNCPrefix) && url.Host[:len(UNCPrefix)] == UNCPrefix { - idx := strings.Index(url.Host[len(UNCPrefix):], string(os.PathSeparator)) - if idx > -1 { - url.Path = filepath.ToSlash(url.Host[idx+len(UNCPrefix):]) + url.Path - url.Host = url.Host[:idx+len(UNCPrefix)] - } - } - - // clean up backward-slashes since they only matter when part of a unc path - urlPath := filepath.ToSlash(url.Path) - - // semi-absolute path (current drive letter) -- file:///absolute/path - if url.Host == "" && len(urlPath) > 0 && urlPath[0] == '/' { - return path.Join(filepath.VolumeName(basepath), urlPath) - - // relative path -- file://./relative/path - // file://relative/path - } else if url.Host == "" || (len(url.Host) > 0 && url.Host[0] == '.') { - return path.Join(filepath.ToSlash(basepath), urlPath) - } - - // absolute path - // UNC -- file://\\host/share/whatever - // drive -- file://c:/absolute/path - return path.Join(url.Host, urlPath) -} - func (d *DownloadClient) Get() (string, error) { // If we already have the file and it matches, then just return the target path. if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { @@ -153,6 +126,11 @@ func (d *DownloadClient) Get() (string, error) { log.Printf("Parsed URL: %#v", u) + /* FIXME: + handle the special case of d.config.CopyFile which returns the path + in an os-specific format. + */ + // Files when we don't copy the file are special cased. var f *os.File var finalPath string @@ -271,7 +249,7 @@ func (*HTTPDownloader) Cancel() { } func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { - log.Printf("Starting download: %s", src.String()) + log.Printf("Starting download over HTTP: %s", src.String()) // Seek to the beginning by default if _, err := dst.Seek(0, 0); err != nil { @@ -350,3 +328,298 @@ func (d *HTTPDownloader) Progress() uint { func (d *HTTPDownloader) Total() uint { return d.total } + +// FTPDownloader is an implementation of Downloader that downloads +// files over FTP. +type FTPDownloader struct { + userInfo url.UserInfo + mtu uint + + active bool + progress uint + total uint +} + +func (*FTPDownloader) Cancel() { + d.active = false +} + +func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { + var userinfo *url.Userinfo + + userinfo = d.userInfo + d.active = false + + // check the uri is correct + uri, err := url.Parse(src) + if err != nil { return err } + + if uri.Scheme != "ftp" { + return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme) + } + + // connect to ftp server + var cli *ftp.ServerConn + + log.Printf("Starting download over FTP: %s : %s\n", uri.Host, Uri.Path) + cli,err := ftp.Dial(uri.Host) + if err != nil { return nil } + defer cli.Close() + + // handle authentication + if uri.User != nil { userinfo = uri.User } + + log.Printf("Authenticating to FTP server: %s : %s\n", uri.User.username, uri.User.password) + err = cli.Login(userinfo.username, userinfo.password) + if err != nil { return err } + + // locate specified path + path := path.Dir(uri.Path) + + log.Printf("Changing to FTP directory : %s\n", path) + err = cli.ChangeDir(path) + if err != nil { return nil } + + curpath,err := cli.CurrentDir() + if err != nil { return err } + log.Printf("Current FTP directory : %s\n", curpath) + + // collect stats about the specified file + var name string + var entry *ftp.Entry + + _,name = path.Split(uri.Path) + entry = nil + + entries,err := cli.List(curpath) + for _,e := range entries { + if e.Type == ftp.EntryTypeFile && e.Name == name { + entry = e + break + } + } + + if entry == nil { + return fmt.Errorf("Unable to find file: %s", uri.Path) + } + log.Printf("Found file : %s : %v bytes\n", entry.Name, entry.Size) + + d.progress = 0 + d.total = entry.Size + + // download specified file + d.active = true + reader,err := cli.RetrFrom(uri.Path, d.progress) + if err != nil { return nil } + + // do it in a goro so that if someone wants to cancel it, they can + errch := make(chan error) + go func(d *FTPDownloader, r *io.Reader, w *bufio.Writer, e chan error) { + defer w.Flush() + for ; d.active { + n,err := io.CopyN(writer, reader, d.mtu) + if err != nil { break } + d.progress += n + } + d.active = false + e <- err + }(d, reader, bufio.NewWriter(dst), errch) + + // spin until it's done + err = <-errch + reader.Close() + + if err == nil && d.progress != d.total { + err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total) + } + + // log out and quit + cli.Logout() + cli.Quit() + return err +} + +func (d *FTPDownloader) Progress() uint { + return d.progress +} + +func (d *FTPDownloader) Total() uint { + return d.total +} + +// FileDownloader is an implementation of Downloader that downloads +// files using the regular filesystem. +type FileDownloader struct { + bufferSize *uint + + active bool + progress uint + total uint +} + +func (*FileDownloader) Cancel() { + d.active = false +} + +func (d *FileDownloader) Progress() uint { + return d.progress +} + +func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { + d.active = false + + /* parse the uri using the net/url module */ + uri, err := url.Parse(src) + if uri.Scheme != "file" { + return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme) + } + + /* use the current working directory as the base for relative uri's */ + cwd,err := os.Getwd() + if err != nil { + return "", fmt.Errorf("Unable to get working directory") + } + + /* determine which uri format is being used and convert to a real path */ + var realpath string, basepath string + basepath = filepath.ToSlash(cwd) + + // absolute path -- file://c:/absolute/path + if strings.HasSuffix(uri.Host, ":") { + realpath = path.Join(uri.Host, uri.Path) + + // semi-absolute path (current drive letter) -- file:///absolute/path + } else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") { + realpath = path.Join(filepath.VolumeName(basepath), uri.Path) + + // relative path -- file://./relative/path + } else if uri.Host == "." { + realpath = path.Join(basepath, uri.Path) + + // relative path -- file://relative/path + } else { + realpath = path.Join(basepath, uri.Host, uri.Path) + } + + /* download the file using the operating system's facilities */ + d.progress = 0 + d.active = true + + f, err = os.Open(realpath) + if err != nil { return err } + defer f.Close() + + // get the file size + fi, err := f.Stat() + if err != nil { return err } + d.total = fi.Size() + + // no bufferSize specified, so copy synchronously. + if d.bufferSize == nil { + n,err := io.Copy(dst, f) + d.active = false + d.progress += n + + // use a goro in case someone else wants to enable cancel/resume + } else { + errch := make(chan error) + go func(d* FileDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) { + defer w.Flush() + for ; d.active { + n,err := io.CopyN(writer, reader, d.bufferSize) + if err != nil { break } + d.progress += n + } + d.active = false + e <- err + }(d, f, bufio.NewWriter(dst), errch) + + // ...and we spin until it's done + err = <-errch + } + f.Close() + return err +} + +func (d *FileDownloader) Total() uint { + return d.total +} + +// SMBDownloader is an implementation of Downloader that downloads +// files using the "\\" path format on Windows +type SMBDownloader struct { + bufferSize *uint + + active bool + progress uint + total uint +} + +func (*SMBDownloader) Cancel() { + d.active = false +} + +func (d *SMBDownloader) Progress() uint { + return d.progress +} + +func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { + const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator) + d.active = false + + if runtime.GOOS != "windows" { + return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS) + } + + /* convert the uri using the net/url module to a UNC path */ + var realpath string + uri, err := url.Parse(src) + if uri.Scheme != "smb" { + return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme) + } + + realpath = UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)) + + /* Open up the "\\"-prefixed path using the Windows filesystem */ + d.progress = 0 + d.active = true + + f, err = os.Open(realpath) + if err != nil { return err } + defer f.Close() + + // get the file size (at the risk of performance) + fi, err := f.Stat() + if err != nil { return err } + d.total = fi.Size() + + // no bufferSize specified, so copy synchronously. + if d.bufferSize == nil { + n,err := io.Copy(dst, f) + d.active = false + d.progress += n + + // use a goro in case someone else wants to enable cancel/resume + } else { + errch := make(chan error) + go func(d* SMBDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) { + defer w.Flush() + for ; d.active { + n,err := io.CopyN(writer, reader, d.bufferSize) + if err != nil { break } + d.progress += n + } + d.active = false + e <- err + }(d, f, bufio.NewWriter(dst), errch) + + // ...and as usual we spin until it's done + err = <-errch + } + f.Close() + return err +} + +func (d *SMBDownloader) Total() uint { + return d.total +}