Refactored the code a bit to move the CopyFile hack out of DownloadClient and instead into each protocol.

config.go:
Removed all of the windows-specific net/url hackery since it's now handled mostly by download.go
Removed the replacement of '\' with '/' since url.Parse does it now.
Added knowledge of the other protocols implemented in download.go (ftp, smb)
Removed some modules that were unused in this commit.

download.go:
Moved the file-path conversions for the different protocols into their own internally callable functions.
Shuffled some of the functions around in case someone wants to implement the ability to resume.
Modified DownloadClient.Get to remove the CopyFile special case and trust the protocol implementations if a user doesn't want to copy the file.
Since all the protocols except for HTTPDownloader implement Cancel, added a Resume method as a placeholder for another developer to implement.
Added a few missing names from their function definitions.
Fixed the syntax in a few lines due to my suckage at go.
Adjusted the types for progress and total so that they support 64-bit sizes.
Removed the usage of the bufio library since it wasn't really being used.
This commit is contained in:
Ali Rizvi-Santiago 2016-04-05 15:01:06 -05:00
parent 60831801a7
commit 6170e24ecb
2 changed files with 225 additions and 290 deletions

View File

@ -2,10 +2,8 @@ package common
import ( import (
"fmt" "fmt"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"time" "time"
) )
@ -47,118 +45,40 @@ func ChooseString(vals ...string) string {
// a completely valid URL. For example, the original URL might be "local/file.iso" // a completely valid URL. For example, the original URL might be "local/file.iso"
// which isn't a valid URL. DownloadableURL will return "file:///local/file.iso" // which isn't a valid URL. DownloadableURL will return "file:///local/file.iso"
func DownloadableURL(original string) (string, error) { func DownloadableURL(original string) (string, error) {
if runtime.GOOS == "windows" {
// If the distance to the first ":" is just one character, assume
// we're dealing with a drive letter and thus a file path.
// prepend with "file:///"" now so that url.Parse won't accidentally
// parse the drive letter into the url scheme.
// See https://blogs.msdn.microsoft.com/ie/2006/12/06/file-uris-in-windows/
// for more info about valid windows URIs
idx := strings.Index(original, ":")
if idx == 1 {
original = "file://" + filepath.ToSlash(original)
}
}
// XXX: The validation here is later re-parsed in common/download.go and
// thus any modifications here must remain consistent over there too.
uri, err := url.Parse(original)
if err != nil {
return "", err
}
if uri.Scheme == "" {
uri.Scheme = "file"
}
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
if uri.Scheme == "file" {
var ospath string // os-formatted pathname
if runtime.GOOS == "windows" {
// Move any extra path components that were mis-parsed into the Host
// field back into the uri.Path field
if len(uri.Host) >= len(UNCPrefix) && uri.Host[:len(UNCPrefix)] == UNCPrefix {
idx := strings.Index(uri.Host[len(UNCPrefix):], string(os.PathSeparator))
if idx > -1 {
uri.Path = filepath.ToSlash(uri.Host[idx+len(UNCPrefix):]) + uri.Path
uri.Host = uri.Host[:idx+len(UNCPrefix)]
}
}
// Now all we need to do to convert the uri to a platform-specific path
// is to trade it's slashes for some os.PathSeparator ones.
ospath = uri.Host + filepath.FromSlash(uri.Path)
} else {
// Since we're already using sane paths on a sane platform, anything in
// uri.Host can be assumed that the user is describing a relative uri.
// This means that if we concatenate it with uri.Path, the filepath
// transform will still open the file correctly.
// i.e. file://localdirectory/filename -> localdirectory/filename
ospath = uri.Host + uri.Path
}
// Only do the filepath transformations if the file appears
// to actually exist. We don't do it on windows, because EvalSymlinks
// won't understand how to handle UNC paths and other Windows-specific minutae.
if _, err := os.Stat(ospath); err == nil && runtime.GOOS != "windows" {
ospath, err = filepath.Abs(ospath)
if err != nil {
return "", err
}
ospath, err = filepath.EvalSymlinks(ospath)
if err != nil {
return "", err
}
ospath = filepath.Clean(ospath)
}
// now that ospath was normalized and such..
if runtime.GOOS == "windows" {
uri.Host = ""
// Check to see if our ospath is unc-prefixed, and if it is then split
// the UNC host into uri.Host, leaving the rest in ospath.
// This way, our UNC-uri is protected from injury in the call to uri.String()
if len(ospath) >= len(UNCPrefix) && ospath[:len(UNCPrefix)] == UNCPrefix {
idx := strings.Index(ospath[len(UNCPrefix):], string(os.PathSeparator))
if idx > -1 {
uri.Host = ospath[:len(UNCPrefix)+idx]
ospath = ospath[len(UNCPrefix)+idx:]
}
}
// Restore the uri by re-transforming our os-formatted path
uri.Path = filepath.ToSlash(ospath)
} else {
uri.Host = ""
uri.Path = filepath.ToSlash(ospath)
}
}
// Make sure it is lowercased
uri.Scheme = strings.ToLower(uri.Scheme)
// Verify that the scheme is something we support in our common downloader. // Verify that the scheme is something we support in our common downloader.
supported := []string{"file", "http", "https"} supported := []string{"file", "http", "https", "ftp", "smb"}
found := false found := false
for _, s := range supported { for _, s := range supported {
if uri.Scheme == s { if strings.HasPrefix(original, s + "://") {
found = true found = true
break break
} }
} }
if !found { // If it's properly prefixed with something we support, then we don't need
return "", fmt.Errorf("Unsupported URL scheme: %s", uri.Scheme) // to make it a uri.
if found {
return original, nil
} }
// explicit check to see if we need to manually replace the uri host with a UNC one // If the file exists, then make it an absolute path
if runtime.GOOS == "windows" && uri.Scheme == "file" { _,err := os.Stat(original)
if len(uri.Host) >= len(UNCPrefix) && uri.Host[:len(UNCPrefix)] == UNCPrefix { if err == nil {
escapedHost := url.QueryEscape(uri.Host) original, err = filepath.Abs(filepath.FromSlash(original))
return strings.Replace(uri.String(), escapedHost, uri.Host, 1), nil if err != nil { return "", err }
}
original, err = filepath.EvalSymlinks(original)
if err != nil { return "", err }
original = filepath.Clean(original)
original = filepath.ToSlash(original)
} }
return uri.String(), nil
// Since it wasn't properly prefixed, let's make it into a well-formed
// file:// uri.
return "file://" + original, nil
} }
// FileExistsLocally takes the URL output from DownloadableURL, and determines // FileExistsLocally takes the URL output from DownloadableURL, and determines

View File

@ -10,20 +10,20 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash" "hash"
"io"
"log" "log"
"net/url" "net/url"
"os" "os"
"runtime" "runtime"
"path" "path"
"path/filepath"
"strings" "strings"
) )
// imports related to each Downloader implementation
import ( import (
"io"
"path/filepath"
"net/http" "net/http"
"github.com/jlaffeye/ftp" "github.com/jlaffaye/ftp"
"bufio"
) )
@ -89,23 +89,35 @@ func NewDownloadClient(c *DownloadConfig) *DownloadClient {
if c.DownloaderMap == nil { if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{ c.DownloaderMap = map[string]Downloader{
"file": &FileDownloader{bufferSize: nil}, "file": &FileDownloader{bufferSize: nil},
"ftp": &FTPDownloader{userInfo: url.Userinfo{username:"anonymous", password: "anonymous@"}, mtu: mtu}, "ftp": &FTPDownloader{userInfo: url.UserPassword("anonymous", "anonymous@"), mtu: mtu},
"http": &HTTPDownloader{userAgent: c.UserAgent}, "http": &HTTPDownloader{userAgent: c.UserAgent},
"https": &HTTPDownloader{userAgent: c.UserAgent}, "https": &HTTPDownloader{userAgent: c.UserAgent},
"smb": &SMBDownloader{bufferSize: nil} "smb": &SMBDownloader{bufferSize: nil},
} }
} }
return &DownloadClient{config: c} return &DownloadClient{config: c}
} }
// A downloader is responsible for actually taking a remote URL and // A downloader implements the ability to transfer a file, and cancel or resume
// downloading it. // it.
type Downloader interface { type Downloader interface {
Resume()
Cancel() Cancel()
Progress() uint64
Total() uint64
}
// A LocalDownloader is responsible for converting a uri to a local path
// that the platform can open directly.
type LocalDownloader interface {
toPath(string, url.URL) (string,error)
}
// A RemoteDownloader is responsible for actually taking a remote URL and
// downloading it.
type RemoteDownloader interface {
Download(*os.File, *url.URL) error Download(*os.File, *url.URL) error
Progress() uint
Total() uint
} }
func (d *DownloadClient) Cancel() { func (d *DownloadClient) Cancel() {
@ -119,75 +131,54 @@ func (d *DownloadClient) Get() (string, error) {
return d.config.TargetPath, nil return d.config.TargetPath, nil
} }
/* parse the configuration url into a net/url object */
u, err := url.Parse(d.config.Url) u, err := url.Parse(d.config.Url)
if err != nil { if err != nil { return "", err }
return "", err
}
log.Printf("Parsed URL: %#v", u) log.Printf("Parsed URL: %#v", u)
/* FIXME: /* use the current working directory as the base for relative uri's */
handle the special case of d.config.CopyFile which returns the path cwd,err := os.Getwd()
in an os-specific format. if err != nil { return "", err }
*/
// Files when we don't copy the file are special cased. // Determine which is the correct downloader to use
var f *os.File
var finalPath string var finalPath string
sourcePath := ""
if u.Scheme == "file" && !d.config.CopyFile {
// This is special case for relative path in this case user specify
// file:../ and after parse destination goes to Opaque
if u.Path != "" {
// If url.Path is set just use this
finalPath = u.Path
} else if u.Opaque != "" {
// otherwise try url.Opaque
finalPath = u.Opaque
}
// This is a special case where we use a source file that already exists
// locally and we don't make a copy. Normally we would copy or download.
log.Printf("[DEBUG] Using local file: %s", finalPath)
// transform the actual file uri to a windowsy path if we're being windowsy. var ok bool
if runtime.GOOS == "windows" { d.downloader, ok = d.config.DownloaderMap[u.Scheme]
// FIXME: cwd should point to a path relative to the TEMPLATE path, if !ok {
// but since this isn't exposed to us anywhere, we use os.Getwd() return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
// and assume the user ran packer in the same directory that }
// any relative files are located at.
cwd,err := os.Getwd()
if err != nil {
return "", fmt.Errorf("Unable to get working directory")
}
finalPath = NormalizeWindowsURL(cwd, *url)
}
// Keep track of the source so we can make sure not to delete this later remote,ok := d.downloader.(RemoteDownloader)
sourcePath = finalPath if !ok {
if _, err = os.Stat(finalPath); err != nil { return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader : %T", u.Scheme, d.downloader)
return "", err }
}
} else { local,ok := d.downloader.(LocalDownloader)
if !ok && !d.config.CopyFile{
return "", fmt.Errorf("Not allowed to use uri scheme %s in no copy file mode : %T", u.Scheme, d.downloader)
}
// If we're copying the file, then just use the actual downloader
if d.config.CopyFile {
var f *os.File
finalPath = d.config.TargetPath finalPath = d.config.TargetPath
var ok bool
d.downloader, ok = d.config.DownloaderMap[u.Scheme]
if !ok {
return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
}
// Otherwise, download using the downloader.
f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
if err != nil { if err != nil { return "", err }
return "", err
}
log.Printf("[DEBUG] Downloading: %s", u.String()) log.Printf("[DEBUG] Downloading: %s", u.String())
err = d.downloader.Download(f, u) err = remote.Download(f, u)
f.Close() f.Close()
if err != nil { if err != nil { return "", err }
return "", err
} // Otherwise if our Downloader is a LocalDownloader we can just use the
// path after transforming it.
} else {
finalPath,err = local.toPath(cwd, *u)
if err != nil { return "", err }
log.Printf("[DEBUG] Using local file: %s", finalPath)
} }
if d.config.Hash != nil { if d.config.Hash != nil {
@ -195,9 +186,7 @@ func (d *DownloadClient) Get() (string, error) {
verify, err = d.VerifyChecksum(finalPath) verify, err = d.VerifyChecksum(finalPath)
if err == nil && !verify { if err == nil && !verify {
// Only delete the file if we made a copy or downloaded it // Only delete the file if we made a copy or downloaded it
if sourcePath != finalPath { if d.config.CopyFile { os.Remove(finalPath) }
os.Remove(finalPath)
}
err = fmt.Errorf( err = fmt.Errorf(
"checksums didn't match expected: %s", "checksums didn't match expected: %s",
@ -210,10 +199,7 @@ func (d *DownloadClient) Get() (string, error) {
// PercentProgress returns the download progress as a percentage. // PercentProgress returns the download progress as a percentage.
func (d *DownloadClient) PercentProgress() int { func (d *DownloadClient) PercentProgress() int {
if d.downloader == nil { if d.downloader == nil { return -1 }
return -1
}
return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100) return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
} }
@ -239,12 +225,16 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
// HTTPDownloader is an implementation of Downloader that downloads // HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP. // files over HTTP.
type HTTPDownloader struct { type HTTPDownloader struct {
progress uint progress uint64
total uint total uint64
userAgent string userAgent string
} }
func (*HTTPDownloader) Cancel() { func (d *HTTPDownloader) Cancel() {
// TODO(mitchellh): Implement
}
func (d *HTTPDownloader) Resume() {
// TODO(mitchellh): Implement // TODO(mitchellh): Implement
} }
@ -285,7 +275,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
if fi, err := dst.Stat(); err == nil { if fi, err := dst.Stat(); err == nil {
if _, err = dst.Seek(0, os.SEEK_END); err == nil { if _, err = dst.Seek(0, os.SEEK_END); err == nil {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
d.progress = uint(fi.Size()) d.progress = uint64(fi.Size())
} }
} }
} }
@ -299,7 +289,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err return err
} }
d.total = d.progress + uint(resp.ContentLength) d.total = d.progress + uint64(resp.ContentLength)
var buffer [4096]byte var buffer [4096]byte
for { for {
n, err := resp.Body.Read(buffer[:]) n, err := resp.Body.Read(buffer[:])
@ -307,7 +297,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err return err
} }
d.progress += uint(n) d.progress += uint64(n)
if _, werr := dst.Write(buffer[:n]); werr != nil { if _, werr := dst.Write(buffer[:n]); werr != nil {
return werr return werr
@ -321,29 +311,41 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return nil return nil
} }
func (d *HTTPDownloader) Progress() uint { func (d *HTTPDownloader) Progress() uint64 {
return d.progress return d.progress
} }
func (d *HTTPDownloader) Total() uint { func (d *HTTPDownloader) Total() uint64 {
return d.total return d.total
} }
// FTPDownloader is an implementation of Downloader that downloads // FTPDownloader is an implementation of Downloader that downloads
// files over FTP. // files over FTP.
type FTPDownloader struct { type FTPDownloader struct {
userInfo url.UserInfo userInfo *url.Userinfo
mtu uint mtu uint
active bool active bool
progress uint progress uint64
total uint total uint64
} }
func (*FTPDownloader) Cancel() { func (d *FTPDownloader) Progress() uint64 {
return d.progress
}
func (d *FTPDownloader) Total() uint64 {
return d.total
}
func (d *FTPDownloader) Cancel() {
d.active = false d.active = false
} }
func (d *FTPDownloader) Resume() {
// TODO: Implement
}
func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error { func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
var userinfo *url.Userinfo var userinfo *url.Userinfo
@ -351,33 +353,34 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
d.active = false d.active = false
// check the uri is correct // check the uri is correct
uri, err := url.Parse(src) if src == nil || src.Scheme != "ftp" {
if err != nil { return err } return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
if uri.Scheme != "ftp" {
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
} }
uri := src
// connect to ftp server // connect to ftp server
var cli *ftp.ServerConn var cli *ftp.ServerConn
log.Printf("Starting download over FTP: %s : %s\n", uri.Host, Uri.Path) log.Printf("Starting download over FTP: %s : %s\n", uri.Host, uri.Path)
cli,err := ftp.Dial(uri.Host) cli,err := ftp.Dial(uri.Host)
if err != nil { return nil } if err != nil { return nil }
defer cli.Close() defer cli.Quit()
// handle authentication // handle authentication
if uri.User != nil { userinfo = uri.User } if uri.User != nil { userinfo = uri.User }
log.Printf("Authenticating to FTP server: %s : %s\n", uri.User.username, uri.User.password) pass,ok := userinfo.Password()
err = cli.Login(userinfo.username, userinfo.password) if !ok { pass = "ftp@" }
log.Printf("Authenticating to FTP server: %s : %s\n", userinfo.Username(), pass)
err = cli.Login(userinfo.Username(), pass)
if err != nil { return err } if err != nil { return err }
// locate specified path // locate specified path
path := path.Dir(uri.Path) p := path.Dir(uri.Path)
log.Printf("Changing to FTP directory : %s\n", path) log.Printf("Changing to FTP directory : %s\n", p)
err = cli.ChangeDir(path) err = cli.ChangeDir(p)
if err != nil { return nil } if err != nil { return nil }
curpath,err := cli.CurrentDir() curpath,err := cli.CurrentDir()
@ -393,7 +396,7 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
entries,err := cli.List(curpath) entries,err := cli.List(curpath)
for _,e := range entries { for _,e := range entries {
if e.Type == ftp.EntryTypeFile && e.Name == name { if e.Type == ftp.EntryTypeFile && e.Name == name {
entry = e entry = e
break break
} }
@ -414,16 +417,15 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
// do it in a goro so that if someone wants to cancel it, they can // do it in a goro so that if someone wants to cancel it, they can
errch := make(chan error) errch := make(chan error)
go func(d *FTPDownloader, r *io.Reader, w *bufio.Writer, e chan error) { go func(d *FTPDownloader, r io.Reader, w io.Writer, e chan error) {
defer w.Flush() for ; d.active; {
for ; d.active { n,err := io.CopyN(w, r, int64(d.mtu))
n,err := io.CopyN(writer, reader, d.mtu)
if err != nil { break } if err != nil { break }
d.progress += n d.progress += uint64(n)
} }
d.active = false d.active = false
e <- err e <- err
}(d, reader, bufio.NewWriter(dst), errch) }(d, reader, dst, errch)
// spin until it's done // spin until it's done
err = <-errch err = <-errch
@ -433,106 +435,109 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total) err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total)
} }
// log out and quit // log out
cli.Logout() cli.Logout()
cli.Quit()
return err return err
} }
func (d *FTPDownloader) Progress() uint {
return d.progress
}
func (d *FTPDownloader) Total() uint {
return d.total
}
// FileDownloader is an implementation of Downloader that downloads // FileDownloader is an implementation of Downloader that downloads
// files using the regular filesystem. // files using the regular filesystem.
type FileDownloader struct { type FileDownloader struct {
bufferSize *uint bufferSize *uint
active bool active bool
progress uint progress uint64
total uint total uint64
} }
func (*FileDownloader) Cancel() { func (d *FileDownloader) Progress() uint64 {
return d.progress
}
func (d *FileDownloader) Total() uint64 {
return d.total
}
func (d *FileDownloader) Cancel() {
d.active = false d.active = false
} }
func (d *FileDownloader) Progress() uint { func (d *FileDownloader) Resume() {
return d.progress // TODO: Implement
}
func (d *FileDownloader) toPath(base string, uri url.URL) (string,error) {
var result string
// absolute path -- file://c:/absolute/path -> c:/absolute/path
if strings.HasSuffix(uri.Host, ":") {
result = path.Join(uri.Host, uri.Path)
// semi-absolute path (current drive letter)
// -- file:///absolute/path -> /absolute/path
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
result = path.Join(filepath.VolumeName(base), uri.Path)
// relative path -- file://./relative/path -> ./relative/path
} else if uri.Host == "." {
result = path.Join(base, uri.Path)
// relative path -- file://relative/path -> ./relative/path
} else {
result = path.Join(base, uri.Host, uri.Path)
}
return filepath.ToSlash(result),nil
} }
func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
d.active = false d.active = false
/* parse the uri using the net/url module */ /* check the uri's scheme to make sure it matches */
uri, err := url.Parse(src) if src == nil || src.Scheme != "file" {
if uri.Scheme != "file" { return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
} }
uri := src
/* use the current working directory as the base for relative uri's */ /* use the current working directory as the base for relative uri's */
cwd,err := os.Getwd() cwd,err := os.Getwd()
if err != nil { if err != nil { return err }
return "", fmt.Errorf("Unable to get working directory")
}
/* determine which uri format is being used and convert to a real path */ /* determine which uri format is being used and convert to a real path */
var realpath string, basepath string realpath,err := d.toPath(cwd, *uri)
basepath = filepath.ToSlash(cwd) if err != nil { return err }
// absolute path -- file://c:/absolute/path
if strings.HasSuffix(uri.Host, ":") {
realpath = path.Join(uri.Host, uri.Path)
// semi-absolute path (current drive letter) -- file:///absolute/path
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
realpath = path.Join(filepath.VolumeName(basepath), uri.Path)
// relative path -- file://./relative/path
} else if uri.Host == "." {
realpath = path.Join(basepath, uri.Path)
// relative path -- file://relative/path
} else {
realpath = path.Join(basepath, uri.Host, uri.Path)
}
/* download the file using the operating system's facilities */ /* download the file using the operating system's facilities */
d.progress = 0 d.progress = 0
d.active = true d.active = true
f, err = os.Open(realpath) f, err := os.Open(realpath)
if err != nil { return err } if err != nil { return err }
defer f.Close() defer f.Close()
// get the file size // get the file size
fi, err := f.Stat() fi, err := f.Stat()
if err != nil { return err } if err != nil { return err }
d.total = fi.Size() d.total = uint64(fi.Size())
// no bufferSize specified, so copy synchronously. // no bufferSize specified, so copy synchronously.
if d.bufferSize == nil { if d.bufferSize == nil {
n,err := io.Copy(dst, f) var n int64
n,err = io.Copy(dst, f)
d.active = false d.active = false
d.progress += n d.progress += uint64(n)
// use a goro in case someone else wants to enable cancel/resume // use a goro in case someone else wants to enable cancel/resume
} else { } else {
errch := make(chan error) errch := make(chan error)
go func(d* FileDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) { go func(d* FileDownloader, r io.Reader, w io.Writer, e chan error) {
defer w.Flush() for ; d.active; {
for ; d.active { n,err := io.CopyN(w, r, int64(*d.bufferSize))
n,err := io.CopyN(writer, reader, d.bufferSize)
if err != nil { break } if err != nil { break }
d.progress += n d.progress += uint64(n)
} }
d.active = false d.active = false
e <- err e <- err
}(d, f, bufio.NewWriter(dst), errch) }(d, f, dst, errch)
// ...and we spin until it's done // ...and we spin until it's done
err = <-errch err = <-errch
@ -541,77 +546,91 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
return err return err
} }
func (d *FileDownloader) Total() uint {
return d.total
}
// SMBDownloader is an implementation of Downloader that downloads // SMBDownloader is an implementation of Downloader that downloads
// files using the "\\" path format on Windows // files using the "\\" path format on Windows
type SMBDownloader struct { type SMBDownloader struct {
bufferSize *uint bufferSize *uint
active bool active bool
progress uint progress uint64
total uint total uint64
} }
func (*SMBDownloader) Cancel() { func (d *SMBDownloader) Progress() uint64 {
d.active = false
}
func (d *SMBDownloader) Progress() uint {
return d.progress return d.progress
} }
func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { func (d *SMBDownloader) Total() uint64 {
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator) return d.total
}
func (d *SMBDownloader) Cancel() {
d.active = false d.active = false
}
func (d *SMBDownloader) Resume() {
// TODO: Implement
}
func (d *SMBDownloader) toPath(base string, uri url.URL) (string,error) {
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS) return "",fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
} }
return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil
}
func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
d.active = false
/* convert the uri using the net/url module to a UNC path */ /* convert the uri using the net/url module to a UNC path */
var realpath string if src == nil || src.Scheme != "smb" {
uri, err := url.Parse(src) return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
if uri.Scheme != "smb" {
return fmt.Errorf("Unexpected uri scheme: %s", uri.Scheme)
} }
uri := src
realpath = UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)) /* use the current working directory as the base for relative uri's */
cwd,err := os.Getwd()
if err != nil { return err }
/* convert uri to an smb-path */
realpath,err := d.toPath(cwd, *uri)
if err != nil { return err }
/* Open up the "\\"-prefixed path using the Windows filesystem */ /* Open up the "\\"-prefixed path using the Windows filesystem */
d.progress = 0 d.progress = 0
d.active = true d.active = true
f, err = os.Open(realpath) f, err := os.Open(realpath)
if err != nil { return err } if err != nil { return err }
defer f.Close() defer f.Close()
// get the file size (at the risk of performance) // get the file size (at the risk of performance)
fi, err := f.Stat() fi, err := f.Stat()
if err != nil { return err } if err != nil { return err }
d.total = fi.Size() d.total = uint64(fi.Size())
// no bufferSize specified, so copy synchronously. // no bufferSize specified, so copy synchronously.
if d.bufferSize == nil { if d.bufferSize == nil {
n,err := io.Copy(dst, f) var n int64
n,err = io.Copy(dst, f)
d.active = false d.active = false
d.progress += n d.progress += uint64(n)
// use a goro in case someone else wants to enable cancel/resume // use a goro in case someone else wants to enable cancel/resume
} else { } else {
errch := make(chan error) errch := make(chan error)
go func(d* SMBDownloader, r *bufio.Reader, w *bufio.Writer, e chan error) { go func(d* SMBDownloader, r io.Reader, w io.Writer, e chan error) {
defer w.Flush() for ; d.active; {
for ; d.active { n,err := io.CopyN(w, r, int64(*d.bufferSize))
n,err := io.CopyN(writer, reader, d.bufferSize)
if err != nil { break } if err != nil { break }
d.progress += n d.progress += uint64(n)
} }
d.active = false d.active = false
e <- err e <- err
}(d, f, bufio.NewWriter(dst), errch) }(d, f, dst, errch)
// ...and as usual we spin until it's done // ...and as usual we spin until it's done
err = <-errch err = <-errch
@ -619,7 +638,3 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
f.Close() f.Close()
return err return err
} }
func (d *SMBDownloader) Total() uint {
return d.total
}