From 6170e24ecbd4a60bb94b6ff3ea11661fb65f8645 Mon Sep 17 00:00:00 2001 From: Ali Rizvi-Santiago Date: Tue, 5 Apr 2016 15:01:06 -0500 Subject: [PATCH] Refactored the code a bit to move the CopyFile hack out of DownloadClient and instead into each protocol. config.go: Removed all of the windows-specific net/url hackery since it's now handled mostly by download.go Removed the replacement of '\' with '/' since url.Parse does it now. Added knowledge of the other protocols implemented in download.go (ftp, smb) Removed some modules that were unused in this commit. download.go: Moved the file-path conversions for the different protocols into their own internally callable functions. Shuffled some of the functions around in case someone wants to implement the ability to resume. Modified DownloadClient.Get to remove the CopyFile special case and trust the protocol implementations if a user doesn't want to copy the file. Since all the protocols except for HTTPDownloader implement Cancel, added a Resume method as a placeholder for another developer to implement. Added a few missing names from their function definitions. Fixed the syntax in a few lines due to my suckage at go. Adjusted the types for progress and total so that they support 64-bit sizes. Removed the usage of the bufio library since it wasn't really being used. --- common/config.go | 124 +++----------- common/download.go | 391 +++++++++++++++++++++++---------------------- 2 files changed, 225 insertions(+), 290 deletions(-) diff --git a/common/config.go b/common/config.go index 7b01191a8..103c51105 100644 --- a/common/config.go +++ b/common/config.go @@ -2,10 +2,8 @@ package common import ( "fmt" - "net/url" "os" "path/filepath" - "runtime" "strings" "time" ) @@ -47,118 +45,40 @@ func ChooseString(vals ...string) string { // a completely valid URL. For example, the original URL might be "local/file.iso" // which isn't a valid URL. DownloadableURL will return "file:///local/file.iso" func DownloadableURL(original string) (string, error) { - if runtime.GOOS == "windows" { - // If the distance to the first ":" is just one character, assume - // we're dealing with a drive letter and thus a file path. - // prepend with "file:///"" now so that url.Parse won't accidentally - // parse the drive letter into the url scheme. - // See https://blogs.msdn.microsoft.com/ie/2006/12/06/file-uris-in-windows/ - // for more info about valid windows URIs - idx := strings.Index(original, ":") - if idx == 1 { - original = "file://" + filepath.ToSlash(original) - } - } - - // XXX: The validation here is later re-parsed in common/download.go and - // thus any modifications here must remain consistent over there too. - uri, err := url.Parse(original) - if err != nil { - return "", err - } - - if uri.Scheme == "" { - uri.Scheme = "file" - } - - const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator) - if uri.Scheme == "file" { - var ospath string // os-formatted pathname - if runtime.GOOS == "windows" { - // Move any extra path components that were mis-parsed into the Host - // field back into the uri.Path field - if len(uri.Host) >= len(UNCPrefix) && uri.Host[:len(UNCPrefix)] == UNCPrefix { - idx := strings.Index(uri.Host[len(UNCPrefix):], string(os.PathSeparator)) - if idx > -1 { - uri.Path = filepath.ToSlash(uri.Host[idx+len(UNCPrefix):]) + uri.Path - uri.Host = uri.Host[:idx+len(UNCPrefix)] - } - } - // Now all we need to do to convert the uri to a platform-specific path - // is to trade it's slashes for some os.PathSeparator ones. - ospath = uri.Host + filepath.FromSlash(uri.Path) - - } else { - // Since we're already using sane paths on a sane platform, anything in - // uri.Host can be assumed that the user is describing a relative uri. - // This means that if we concatenate it with uri.Path, the filepath - // transform will still open the file correctly. - // i.e. file://localdirectory/filename -> localdirectory/filename - ospath = uri.Host + uri.Path - } - // Only do the filepath transformations if the file appears - // to actually exist. We don't do it on windows, because EvalSymlinks - // won't understand how to handle UNC paths and other Windows-specific minutae. - if _, err := os.Stat(ospath); err == nil && runtime.GOOS != "windows" { - ospath, err = filepath.Abs(ospath) - if err != nil { - return "", err - } - - ospath, err = filepath.EvalSymlinks(ospath) - if err != nil { - return "", err - } - - ospath = filepath.Clean(ospath) - } - - // now that ospath was normalized and such.. - if runtime.GOOS == "windows" { - uri.Host = "" - // Check to see if our ospath is unc-prefixed, and if it is then split - // the UNC host into uri.Host, leaving the rest in ospath. - // This way, our UNC-uri is protected from injury in the call to uri.String() - if len(ospath) >= len(UNCPrefix) && ospath[:len(UNCPrefix)] == UNCPrefix { - idx := strings.Index(ospath[len(UNCPrefix):], string(os.PathSeparator)) - if idx > -1 { - uri.Host = ospath[:len(UNCPrefix)+idx] - ospath = ospath[len(UNCPrefix)+idx:] - } - } - // Restore the uri by re-transforming our os-formatted path - uri.Path = filepath.ToSlash(ospath) - } else { - uri.Host = "" - uri.Path = filepath.ToSlash(ospath) - } - } - - // Make sure it is lowercased - uri.Scheme = strings.ToLower(uri.Scheme) // Verify that the scheme is something we support in our common downloader. - supported := []string{"file", "http", "https"} + supported := []string{"file", "http", "https", "ftp", "smb"} found := false for _, s := range supported { - if uri.Scheme == s { + if strings.HasPrefix(original, s + "://") { found = true break } } - if !found { - return "", fmt.Errorf("Unsupported URL scheme: %s", uri.Scheme) + // If it's properly prefixed with something we support, then we don't need + // to make it a uri. + if found { + return original, nil } - // explicit check to see if we need to manually replace the uri host with a UNC one - if runtime.GOOS == "windows" && uri.Scheme == "file" { - if len(uri.Host) >= len(UNCPrefix) && uri.Host[:len(UNCPrefix)] == UNCPrefix { - escapedHost := url.QueryEscape(uri.Host) - return strings.Replace(uri.String(), escapedHost, uri.Host, 1), nil - } + // If the file exists, then make it an absolute path + _,err := os.Stat(original) + if err == nil { + original, err = filepath.Abs(filepath.FromSlash(original)) + if err != nil { return "", err } + + original, err = filepath.EvalSymlinks(original) + if err != nil { return "", err } + + original = filepath.Clean(original) + original = filepath.ToSlash(original) } - return uri.String(), nil + + // Since it wasn't properly prefixed, let's make it into a well-formed + // file:// uri. + + return "file://" + original, nil } // FileExistsLocally takes the URL output from DownloadableURL, and determines diff --git a/common/download.go b/common/download.go index b0550927f..c06f385e8 100644 --- a/common/download.go +++ b/common/download.go @@ -10,20 +10,20 @@ import ( "errors" "fmt" "hash" - "io" "log" "net/url" "os" "runtime" "path" - "path/filepath" "strings" ) +// imports related to each Downloader implementation import ( + "io" + "path/filepath" "net/http" - "github.com/jlaffeye/ftp" - "bufio" + "github.com/jlaffaye/ftp" ) @@ -89,23 +89,35 @@ func NewDownloadClient(c *DownloadConfig) *DownloadClient { if c.DownloaderMap == nil { c.DownloaderMap = map[string]Downloader{ "file": &FileDownloader{bufferSize: nil}, - "ftp": &FTPDownloader{userInfo: url.Userinfo{username:"anonymous", password: "anonymous@"}, mtu: mtu}, + "ftp": &FTPDownloader{userInfo: url.UserPassword("anonymous", "anonymous@"), mtu: mtu}, "http": &HTTPDownloader{userAgent: c.UserAgent}, "https": &HTTPDownloader{userAgent: c.UserAgent}, - "smb": &SMBDownloader{bufferSize: nil} + "smb": &SMBDownloader{bufferSize: nil}, } } return &DownloadClient{config: c} } -// A downloader is responsible for actually taking a remote URL and -// downloading it. +// A downloader implements the ability to transfer a file, and cancel or resume +// it. type Downloader interface { + Resume() Cancel() + Progress() uint64 + Total() uint64 +} + +// A LocalDownloader is responsible for converting a uri to a local path +// that the platform can open directly. +type LocalDownloader interface { + toPath(string, url.URL) (string,error) +} + +// A RemoteDownloader is responsible for actually taking a remote URL and +// downloading it. +type RemoteDownloader interface { Download(*os.File, *url.URL) error - Progress() uint - Total() uint } func (d *DownloadClient) Cancel() { @@ -119,75 +131,54 @@ func (d *DownloadClient) Get() (string, error) { return d.config.TargetPath, nil } + /* parse the configuration url into a net/url object */ u, err := url.Parse(d.config.Url) - if err != nil { - return "", err - } - + if err != nil { return "", err } log.Printf("Parsed URL: %#v", u) - /* FIXME: - handle the special case of d.config.CopyFile which returns the path - in an os-specific format. - */ + /* use the current working directory as the base for relative uri's */ + cwd,err := os.Getwd() + if err != nil { return "", err } - // Files when we don't copy the file are special cased. - var f *os.File + // Determine which is the correct downloader to use var finalPath string - sourcePath := "" - if u.Scheme == "file" && !d.config.CopyFile { - // This is special case for relative path in this case user specify - // file:../ and after parse destination goes to Opaque - if u.Path != "" { - // If url.Path is set just use this - finalPath = u.Path - } else if u.Opaque != "" { - // otherwise try url.Opaque - finalPath = u.Opaque - } - // This is a special case where we use a source file that already exists - // locally and we don't make a copy. Normally we would copy or download. - log.Printf("[DEBUG] Using local file: %s", finalPath) - // transform the actual file uri to a windowsy path if we're being windowsy. - if runtime.GOOS == "windows" { - // FIXME: cwd should point to a path relative to the TEMPLATE path, - // but since this isn't exposed to us anywhere, we use os.Getwd() - // and assume the user ran packer in the same directory that - // any relative files are located at. - cwd,err := os.Getwd() - if err != nil { - return "", fmt.Errorf("Unable to get working directory") - } - finalPath = NormalizeWindowsURL(cwd, *url) - } + var ok bool + d.downloader, ok = d.config.DownloaderMap[u.Scheme] + if !ok { + return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme) + } - // Keep track of the source so we can make sure not to delete this later - sourcePath = finalPath - if _, err = os.Stat(finalPath); err != nil { - return "", err - } - } else { + remote,ok := d.downloader.(RemoteDownloader) + if !ok { + return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader : %T", u.Scheme, d.downloader) + } + + local,ok := d.downloader.(LocalDownloader) + if !ok && !d.config.CopyFile{ + return "", fmt.Errorf("Not allowed to use uri scheme %s in no copy file mode : %T", u.Scheme, d.downloader) + } + + // If we're copying the file, then just use the actual downloader + if d.config.CopyFile { + var f *os.File finalPath = d.config.TargetPath - var ok bool - d.downloader, ok = d.config.DownloaderMap[u.Scheme] - if !ok { - return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme) - } - - // Otherwise, download using the downloader. f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) - if err != nil { - return "", err - } + if err != nil { return "", err } log.Printf("[DEBUG] Downloading: %s", u.String()) - err = d.downloader.Download(f, u) + err = remote.Download(f, u) f.Close() - if err != nil { - return "", err - } + if err != nil { return "", err } + + // Otherwise if our Downloader is a LocalDownloader we can just use the + // path after transforming it. + } else { + finalPath,err = local.toPath(cwd, *u) + if err != nil { return "", err } + + log.Printf("[DEBUG] Using local file: %s", finalPath) } if d.config.Hash != nil { @@ -195,9 +186,7 @@ func (d *DownloadClient) Get() (string, error) { verify, err = d.VerifyChecksum(finalPath) if err == nil && !verify { // Only delete the file if we made a copy or downloaded it - if sourcePath != finalPath { - os.Remove(finalPath) - } + if d.config.CopyFile { os.Remove(finalPath) } err = fmt.Errorf( "checksums didn't match expected: %s", @@ -210,10 +199,7 @@ func (d *DownloadClient) Get() (string, error) { // PercentProgress returns the download progress as a percentage. func (d *DownloadClient) PercentProgress() int { - if d.downloader == nil { - return -1 - } - + if d.downloader == nil { return -1 } return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100) } @@ -239,12 +225,16 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { // HTTPDownloader is an implementation of Downloader that downloads // files over HTTP. type HTTPDownloader struct { - progress uint - total uint + progress uint64 + total uint64 userAgent string } -func (*HTTPDownloader) Cancel() { +func (d *HTTPDownloader) Cancel() { + // TODO(mitchellh): Implement +} + +func (d *HTTPDownloader) Resume() { // TODO(mitchellh): Implement } @@ -285,7 +275,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { if fi, err := dst.Stat(); err == nil { if _, err = dst.Seek(0, os.SEEK_END); err == nil { req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) - d.progress = uint(fi.Size()) + d.progress = uint64(fi.Size()) } } } @@ -299,7 +289,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return err } - d.total = d.progress + uint(resp.ContentLength) + d.total = d.progress + uint64(resp.ContentLength) var buffer [4096]byte for { n, err := resp.Body.Read(buffer[:]) @@ -307,7 +297,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return err } - d.progress += uint(n) + d.progress += uint64(n) if _, werr := dst.Write(buffer[:n]); werr != nil { return werr @@ -321,29 +311,41 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return nil } -func (d *HTTPDownloader) Progress() uint { +func (d *HTTPDownloader) Progress() uint64 { return d.progress } -func (d *HTTPDownloader) Total() uint { +func (d *HTTPDownloader) Total() uint64 { return d.total } // FTPDownloader is an implementation of Downloader that downloads // files over FTP. type FTPDownloader struct { - userInfo url.UserInfo + userInfo *url.Userinfo mtu uint active bool - progress uint - total uint + progress uint64 + total uint64 } -func (*FTPDownloader) Cancel() { +func (d *FTPDownloader) Progress() uint64 { + return d.progress +} + +func (d *FTPDownloader) Total() uint64 { + return d.total +} + +func (d *FTPDownloader) Cancel() { d.active = false } +func (d *FTPDownloader) Resume() { + // TODO: Implement +} + func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { var userinfo *url.Userinfo @@ -351,33 +353,34 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { d.active = false // check the uri is correct - uri, err := url.Parse(src) - if err != nil { return err } - - if uri.Scheme != "ftp" { - return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme) + if src == nil || src.Scheme != "ftp" { + return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme) } + uri := src // connect to ftp server var cli *ftp.ServerConn - log.Printf("Starting download over FTP: %s : %s\n", uri.Host, Uri.Path) + log.Printf("Starting download over FTP: %s : %s\n", uri.Host, uri.Path) cli,err := ftp.Dial(uri.Host) if err != nil { return nil } - defer cli.Close() + defer cli.Quit() // handle authentication if uri.User != nil { userinfo = uri.User } - log.Printf("Authenticating to FTP server: %s : %s\n", uri.User.username, uri.User.password) - err = cli.Login(userinfo.username, userinfo.password) + pass,ok := userinfo.Password() + if !ok { pass = "ftp@" } + + log.Printf("Authenticating to FTP server: %s : %s\n", userinfo.Username(), pass) + err = cli.Login(userinfo.Username(), pass) if err != nil { return err } // locate specified path - path := path.Dir(uri.Path) + p := path.Dir(uri.Path) - log.Printf("Changing to FTP directory : %s\n", path) - err = cli.ChangeDir(path) + log.Printf("Changing to FTP directory : %s\n", p) + err = cli.ChangeDir(p) if err != nil { return nil } curpath,err := cli.CurrentDir() @@ -393,7 +396,7 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { entries,err := cli.List(curpath) for _,e := range entries { - if e.Type == ftp.EntryTypeFile && e.Name == name { + if e.Type == ftp.EntryTypeFile && e.Name == name { entry = e break } @@ -414,16 +417,15 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { // do it in a goro so that if someone wants to cancel it, they can errch := make(chan error) - go func(d *FTPDownloader, r *io.Reader, w *bufio.Writer, e chan error) { - defer w.Flush() - for ; d.active { - n,err := io.CopyN(writer, reader, d.mtu) + go func(d *FTPDownloader, r io.Reader, w io.Writer, e chan error) { + for ; d.active; { + n,err := io.CopyN(w, r, int64(d.mtu)) if err != nil { break } - d.progress += n + d.progress += uint64(n) } d.active = false e <- err - }(d, reader, bufio.NewWriter(dst), errch) + }(d, reader, dst, errch) // spin until it's done err = <-errch @@ -433,106 +435,109 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total) } - // log out and quit + // log out cli.Logout() - cli.Quit() return err } -func (d *FTPDownloader) Progress() uint { - return d.progress -} - -func (d *FTPDownloader) Total() uint { - return d.total -} - // FileDownloader is an implementation of Downloader that downloads // files using the regular filesystem. type FileDownloader struct { bufferSize *uint active bool - progress uint - total uint + progress uint64 + total uint64 } -func (*FileDownloader) Cancel() { +func (d *FileDownloader) Progress() uint64 { + return d.progress +} + +func (d *FileDownloader) Total() uint64 { + return d.total +} + +func (d *FileDownloader) Cancel() { d.active = false } -func (d *FileDownloader) Progress() uint { - return d.progress +func (d *FileDownloader) Resume() { + // TODO: Implement +} + +func (d *FileDownloader) toPath(base string, uri url.URL) (string,error) { + var result string + + // absolute path -- file://c:/absolute/path -> c:/absolute/path + if strings.HasSuffix(uri.Host, ":") { + result = path.Join(uri.Host, uri.Path) + + // semi-absolute path (current drive letter) + // -- file:///absolute/path -> /absolute/path + } else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") { + result = path.Join(filepath.VolumeName(base), uri.Path) + + // relative path -- file://./relative/path -> ./relative/path + } else if uri.Host == "." { + result = path.Join(base, uri.Path) + + // relative path -- file://relative/path -> ./relative/path + } else { + result = path.Join(base, uri.Host, uri.Path) + } + return filepath.ToSlash(result),nil } func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { d.active = false - /* parse the uri using the net/url module */ - uri, err := url.Parse(src) - if uri.Scheme != "file" { - return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme) + /* check the uri's scheme to make sure it matches */ + if src == nil || src.Scheme != "file" { + return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme) } + uri := src /* use the current working directory as the base for relative uri's */ cwd,err := os.Getwd() - if err != nil { - return "", fmt.Errorf("Unable to get working directory") - } + if err != nil { return err } /* determine which uri format is being used and convert to a real path */ - var realpath string, basepath string - basepath = filepath.ToSlash(cwd) - - // absolute path -- file://c:/absolute/path - if strings.HasSuffix(uri.Host, ":") { - realpath = path.Join(uri.Host, uri.Path) - - // semi-absolute path (current drive letter) -- file:///absolute/path - } else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") { - realpath = path.Join(filepath.VolumeName(basepath), uri.Path) - - // relative path -- file://./relative/path - } else if uri.Host == "." { - realpath = path.Join(basepath, uri.Path) - - // relative path -- file://relative/path - } else { - realpath = path.Join(basepath, uri.Host, uri.Path) - } + realpath,err := d.toPath(cwd, *uri) + if err != nil { return err } /* download the file using the operating system's facilities */ d.progress = 0 d.active = true - f, err = os.Open(realpath) + f, err := os.Open(realpath) if err != nil { return err } defer f.Close() // get the file size fi, err := f.Stat() if err != nil { return err } - d.total = fi.Size() + d.total = uint64(fi.Size()) // no bufferSize specified, so copy synchronously. if d.bufferSize == nil { - n,err := io.Copy(dst, f) + var n int64 + n,err = io.Copy(dst, f) d.active = false - d.progress += n + d.progress += uint64(n) // use a goro in case someone else wants to enable cancel/resume } else { errch := make(chan error) - go func(d* FileDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) { - defer w.Flush() - for ; d.active { - n,err := io.CopyN(writer, reader, d.bufferSize) + go func(d* FileDownloader, r io.Reader, w io.Writer, e chan error) { + for ; d.active; { + n,err := io.CopyN(w, r, int64(*d.bufferSize)) if err != nil { break } - d.progress += n + d.progress += uint64(n) } d.active = false e <- err - }(d, f, bufio.NewWriter(dst), errch) + }(d, f, dst, errch) // ...and we spin until it's done err = <-errch @@ -541,77 +546,91 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { return err } -func (d *FileDownloader) Total() uint { - return d.total -} - // SMBDownloader is an implementation of Downloader that downloads // files using the "\\" path format on Windows type SMBDownloader struct { bufferSize *uint active bool - progress uint - total uint + progress uint64 + total uint64 } -func (*SMBDownloader) Cancel() { - d.active = false -} - -func (d *SMBDownloader) Progress() uint { +func (d *SMBDownloader) Progress() uint64 { return d.progress } -func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { - const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator) +func (d *SMBDownloader) Total() uint64 { + return d.total +} + +func (d *SMBDownloader) Cancel() { d.active = false +} + +func (d *SMBDownloader) Resume() { + // TODO: Implement +} + +func (d *SMBDownloader) toPath(base string, uri url.URL) (string,error) { + const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator) if runtime.GOOS != "windows" { - return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS) + return "",fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS) } + return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil +} + +func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { + d.active = false + /* convert the uri using the net/url module to a UNC path */ - var realpath string - uri, err := url.Parse(src) - if uri.Scheme != "smb" { - return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme) + if src == nil || src.Scheme != "smb" { + return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme) } + uri := src - realpath = UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)) + /* use the current working directory as the base for relative uri's */ + cwd,err := os.Getwd() + if err != nil { return err } + + /* convert uri to an smb-path */ + realpath,err := d.toPath(cwd, *uri) + if err != nil { return err } /* Open up the "\\"-prefixed path using the Windows filesystem */ d.progress = 0 d.active = true - f, err = os.Open(realpath) + f, err := os.Open(realpath) if err != nil { return err } defer f.Close() // get the file size (at the risk of performance) fi, err := f.Stat() if err != nil { return err } - d.total = fi.Size() + d.total = uint64(fi.Size()) // no bufferSize specified, so copy synchronously. if d.bufferSize == nil { - n,err := io.Copy(dst, f) + var n int64 + n,err = io.Copy(dst, f) d.active = false - d.progress += n + d.progress += uint64(n) // use a goro in case someone else wants to enable cancel/resume } else { errch := make(chan error) - go func(d* SMBDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) { - defer w.Flush() - for ; d.active { - n,err := io.CopyN(writer, reader, d.bufferSize) + go func(d* SMBDownloader, r io.Reader, w io.Writer, e chan error) { + for ; d.active; { + n,err := io.CopyN(w, r, int64(*d.bufferSize)) if err != nil { break } - d.progress += n + d.progress += uint64(n) } d.active = false e <- err - }(d, f, bufio.NewWriter(dst), errch) + }(d, f, dst, errch) // ...and as usual we spin until it's done err = <-errch @@ -619,7 +638,3 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { f.Close() return err } - -func (d *SMBDownloader) Total() uint { - return d.total -}