diff --git a/builder/virtualbox/common/step_download_guest_additions.go b/builder/virtualbox/common/step_download_guest_additions.go index facc60b6a..706d5f466 100644 --- a/builder/virtualbox/common/step_download_guest_additions.go +++ b/builder/virtualbox/common/step_download_guest_additions.go @@ -122,7 +122,7 @@ func (s *StepDownloadGuestAdditions) Run(ctx context.Context, state multistep.St } // Convert the file/url to an actual URL for step_download to process. - url, err = common.DownloadableURL(url) + url, err = common.ValidatedURL(url) if err != nil { err := fmt.Errorf("Error preparing guest additions url: %s", err) state.Put("error", err) diff --git a/builder/virtualbox/ovf/config.go b/builder/virtualbox/ovf/config.go index 41a013ebd..5eef4507c 100644 --- a/builder/virtualbox/ovf/config.go +++ b/builder/virtualbox/ovf/config.go @@ -97,7 +97,7 @@ func NewConfig(raws ...interface{}) (*Config, []string, error) { if c.SourcePath == "" { errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is required")) } else { - c.SourcePath, err = common.DownloadableURL(c.SourcePath) + c.SourcePath, err = common.ValidatedURL(c.SourcePath) if err != nil { errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is invalid: %s", err)) } diff --git a/builder/virtualbox/ovf/config_test.go b/builder/virtualbox/ovf/config_test.go index 5e67b242d..51412a0c5 100644 --- a/builder/virtualbox/ovf/config_test.go +++ b/builder/virtualbox/ovf/config_test.go @@ -84,7 +84,7 @@ func TestNewConfig_sourcePath(t *testing.T) { t.Fatalf("bad: %#v", warns) } if err == nil { - t.Fatal("should error") + t.Fatalf("should error") } // Good diff --git a/common/config.go b/common/config.go index 7bef253ee..aa632ceae 100644 --- a/common/config.go +++ b/common/config.go @@ -5,6 +5,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "runtime" "strings" "time" @@ -43,89 +44,126 @@ func ChooseString(vals ...string) string { return "" } -// DownloadableURL processes a URL that may also be a file path and returns -// a completely valid URL. For example, the original URL might be "local/file.iso" -// which isn't a valid URL. DownloadableURL will return "file:///local/file.iso" -func DownloadableURL(original string) (string, error) { - if runtime.GOOS == "windows" { - // If the distance to the first ":" is just one character, assume - // we're dealing with a drive letter and thus a file path. - // prepend with "file:///"" now so that url.Parse won't accidentally - // parse the drive letter into the url scheme. - // See https://blogs.msdn.microsoft.com/ie/2006/12/06/file-uris-in-windows/ - // for more info about valid windows URIs - idx := strings.Index(original, ":") - if idx == 1 { - original = "file:///" + original +// SupportedProtocol verifies that the url passed is actually supported or not +// This will also validate that the protocol is one that's actually implemented. +func SupportedProtocol(u *url.URL) bool { + // url.Parse shouldn't return nil except on error....but it can. + if u == nil { + return false + } + + // build a dummy NewDownloadClient since this is the only place that valid + // protocols are actually exposed. + cli := NewDownloadClient(&DownloadConfig{}) + + // Iterate through each downloader to see if a protocol was found. + ok := false + for scheme := range cli.config.DownloaderMap { + if strings.ToLower(u.Scheme) == strings.ToLower(scheme) { + ok = true } } + return ok +} + +// DownloadableURL processes a URL that may also be a file path and returns +// a completely valid URL representing the requested file. For example, +// the original URL might be "local/file.iso" which isn't a valid URL, +// and so DownloadableURL will return "file://local/file.iso" +// No other transformations are done to the path. +func DownloadableURL(original string) (string, error) { + var absPrefix, result string + + absPrefix = "" + if runtime.GOOS == "windows" { + absPrefix = "/" + } + + // Check that the user specified a UNC path, and promote it to an smb:// uri. + if strings.HasPrefix(original, "\\\\") && len(original) > 2 && original[2] != '?' { + result = filepath.ToSlash(original[2:]) + return fmt.Sprintf("smb://%s", result), nil + } + + // Fix the url if it's using bad characters commonly mistaken with a path. + original = filepath.ToSlash(original) + + // Check to see that this is a parseable URL with a scheme and a host. + // If so, then just pass it through. + if u, err := url.Parse(original); err == nil && u.Scheme != "" && u.Host != "" { + return original, nil + } + + // If it's a file scheme, then convert it back to a regular path so the next + // case which forces it to an absolute path, will correct it. + if u, err := url.Parse(original); err == nil && strings.ToLower(u.Scheme) == "file" { + original = u.Path + } + + // If we're on Windows and we start with a slash, then this absolute path + // is wrong. Fix it up, so the next case can figure out the absolute path. + if rpath := strings.SplitN(original, "/", 2); rpath[0] == "" && runtime.GOOS == "windows" { + result = rpath[1] + } else { + result = original + } + + // Since we should be some kind of path (relative or absolute), check + // that the file exists, then make it an absolute path so we can return an + // absolute uri. + if _, err := os.Stat(result); err == nil { + result, err = filepath.Abs(filepath.FromSlash(result)) + if err != nil { + return "", err + } + + result, err = filepath.EvalSymlinks(result) + if err != nil { + return "", err + } + + result = filepath.Clean(result) + return fmt.Sprintf("file://%s%s", absPrefix, filepath.ToSlash(result)), nil + } + + // Otherwise, check if it was originally an absolute path, and fix it if so. + if strings.HasPrefix(original, "/") { + return fmt.Sprintf("file://%s%s", absPrefix, result), nil + } + + // Anything left should be a non-existent relative path. So fix it up here. + result = filepath.ToSlash(filepath.Clean(result)) + return fmt.Sprintf("file://./%s", result), nil +} + +// Force the parameter into a url. This will transform the parameter into +// a proper url, removing slashes, adding the proper prefix, etc. +func ValidatedURL(original string) (string, error) { + + // See if the user failed to give a url + if ok, _ := regexp.MatchString("(?m)^[^[:punct:]]+://", original); !ok { + + // So since no magic was found, this must be a path. + result, err := DownloadableURL(original) + if err == nil { + return ValidatedURL(result) + } + + return "", err + } + + // Verify that the url is parseable...just in case. u, err := url.Parse(original) if err != nil { return "", err } - if u.Scheme == "" { - u.Scheme = "file" + // We should now have a url, so verify that it's a protocol we support. + if !SupportedProtocol(u) { + return "", fmt.Errorf("Unsupported protocol scheme! (%#v)", u) } - if u.Scheme == "file" { - // Windows file handling is all sorts of tricky... - if runtime.GOOS == "windows" { - // If the path is using Windows-style slashes, URL parses - // it into the host field. - if u.Path == "" && strings.Contains(u.Host, `\`) { - u.Path = u.Host - u.Host = "" - } - } - // Only do the filepath transformations if the file appears - // to actually exist. - if _, err := os.Stat(u.Path); err == nil { - u.Path, err = filepath.Abs(u.Path) - if err != nil { - return "", err - } - - u.Path, err = filepath.EvalSymlinks(u.Path) - if err != nil { - return "", err - } - - u.Path = filepath.Clean(u.Path) - } - - if runtime.GOOS == "windows" { - // Also replace all backslashes with forwardslashes since Windows - // users are likely to do this but the URL should actually only - // contain forward slashes. - u.Path = strings.Replace(u.Path, `\`, `/`, -1) - // prepend absolute windows paths with "/" so that when we - // compose u.String() below the outcome will be correct - // file:///c/blah syntax; otherwise u.String() will only add - // file:// which is not technically a correct windows URI - if filepath.IsAbs(u.Path) && !strings.HasPrefix(u.Path, "/") { - u.Path = "/" + u.Path - } - - } - } - - // Make sure it is lowercased - u.Scheme = strings.ToLower(u.Scheme) - - // Verify that the scheme is something we support in our common downloader. - supported := []string{"file", "http", "https"} - found := false - for _, s := range supported { - if u.Scheme == s { - found = true - break - } - } - - if !found { - return "", fmt.Errorf("Unsupported URL scheme: %s", u.Scheme) - } + // We should now have a properly formatted and supported url return u.String(), nil } @@ -144,27 +182,38 @@ func DownloadableURL(original string) (string, error) { func FileExistsLocally(original string) bool { // original should be something like file://C:/my/path.iso + u, _ := url.Parse(original) - fileURL, _ := url.Parse(original) - fileExists := false - - if fileURL.Scheme == "file" { - // on windows, correct URI is file:///c:/blah/blah.iso. - // url.Parse will pull out the scheme "file://" and leave the path as - // "/c:/blah/blah/iso". Here we remove this forward slash on absolute - // Windows file URLs before processing - // see https://blogs.msdn.microsoft.com/ie/2006/12/06/file-uris-in-windows/ - // for more info about valid windows URIs - filePath := fileURL.Path - if runtime.GOOS == "windows" && len(filePath) > 0 && filePath[0] == '/' { - filePath = filePath[1:] - } - _, err := os.Stat(filePath) - if err != nil { - return fileExists - } else { - fileExists = true - } + // First create a dummy downloader so we can figure out which + // protocol to use. + cli := NewDownloadClient(&DownloadConfig{}) + d, ok := cli.config.DownloaderMap[u.Scheme] + if !ok { + return false } - return fileExists + + // Check to see that it's got a Local way of doing things. + local, ok := d.(LocalDownloader) + if !ok { + return true // XXX: Remote URLs short-circuit this logic. + } + + // Figure out where we're at. + wd, err := os.Getwd() + if err != nil { + return false + } + + // Now figure out the real path to the file. + realpath, err := local.toPath(wd, *u) + if err != nil { + return false + } + + // Finally we can seek the truth via os.Stat. + _, err = os.Stat(realpath) + if err != nil { + return false + } + return true } diff --git a/common/config_test.go b/common/config_test.go index 365515109..8f170fe1d 100644 --- a/common/config_test.go +++ b/common/config_test.go @@ -37,7 +37,28 @@ func TestChooseString(t *testing.T) { } } -func TestDownloadableURL(t *testing.T) { +func TestValidatedURL(t *testing.T) { + // Invalid URL: has hex code in host + _, err := ValidatedURL("http://what%20.com") + if err == nil { + t.Fatalf("expected err : %s", err) + } + + // Invalid: unsupported scheme + _, err = ValidatedURL("ftp://host.com/path") + if err == nil { + t.Fatalf("expected err : %s", err) + } + + // Valid: http + u, err := ValidatedURL("HTTP://packer.io/path") + if err != nil { + t.Fatalf("err: %s", err) + } + + if u != "http://packer.io/path" { + t.Fatalf("bad: %s", u) + } cases := []struct { InputString string @@ -55,7 +76,7 @@ func TestDownloadableURL(t *testing.T) { } for _, tc := range cases { - u, err := DownloadableURL(tc.InputString) + u, err := ValidatedURL(tc.InputString) if u != tc.OutputURL { t.Fatal(fmt.Sprintf("Error with URL %s: got %s but expected %s", tc.InputString, tc.OutputURL, u)) @@ -63,78 +84,127 @@ func TestDownloadableURL(t *testing.T) { if (err != nil) != tc.ErrExpected { if tc.ErrExpected == true { t.Fatal(fmt.Sprintf("Error with URL %s: we expected "+ - "DownloadableURL to return an error but didn't get one.", + "ValidatedURL to return an error but didn't get one.", tc.InputString)) } else { t.Fatal(fmt.Sprintf("Error with URL %s: we did not expect an "+ - " error from DownloadableURL but we got: %s", + " error from ValidatedURL but we got: %s", tc.InputString, err)) } } } } +func GetNativePathToTestFixtures(t *testing.T) string { + const path = "./test-fixtures" + res, err := filepath.Abs(path) + if err != nil { + t.Fatalf("err converting test-fixtures path into an absolute path : %s", err) + } + return res +} + +func GetPortablePathToTestFixtures(t *testing.T) string { + res := GetNativePathToTestFixtures(t) + return filepath.ToSlash(res) +} + func TestDownloadableURL_WindowsFiles(t *testing.T) { if runtime.GOOS == "windows" { + portablepath := GetPortablePathToTestFixtures(t) + nativepath := GetNativePathToTestFixtures(t) + dirCases := []struct { InputString string OutputURL string ErrExpected bool }{ // TODO: add different directories { - "C:\\Temp\\SomeDir\\myfile.txt", - "file:///C:/Temp/SomeDir/myfile.txt", + fmt.Sprintf("%s\\SomeDir\\myfile.txt", nativepath), + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), false, }, - { // need windows drive - "\\Temp\\SomeDir\\myfile.txt", - "", - true, - }, - { // need windows drive - "/Temp/SomeDir/myfile.txt", - "", - true, - }, - { // UNC paths; why not? - "\\\\?\\c:\\Temp\\SomeDir\\myfile.txt", - "", - true, - }, - { - "file:///C:\\Temp\\SomeDir\\myfile.txt", - "file:///c:/Temp/SomeDir/myfile.txt", + { // without the drive makes this native path a relative file:// uri + "test-fixtures\\SomeDir\\myfile.txt", + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), false, }, - { - "file:///c:/Temp/Somedir/myfile.txt", - "file:///c:/Temp/SomeDir/myfile.txt", + { // without the drive makes this native path a relative file:// uri + "test-fixtures/SomeDir/myfile.txt", + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // UNC paths being promoted to smb:// uri scheme. + fmt.Sprintf("\\\\localhost\\C$\\%s\\SomeDir\\myfile.txt", nativepath), + fmt.Sprintf("smb://localhost/C$/%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // Absolute uri (incorrect slash type) + fmt.Sprintf("file:///%s\\SomeDir\\myfile.txt", nativepath), + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // Absolute uri (existing and mis-spelled) + fmt.Sprintf("file:///%s/Somedir/myfile.txt", nativepath), + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // Absolute path (non-existing) + "\\absolute\\path\\to\\non-existing\\file.txt", + "file:///absolute/path/to/non-existing/file.txt", + false, + }, + { // Absolute paths (existing) + fmt.Sprintf("%s/SomeDir/myfile.txt", nativepath), + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // Relative path (non-existing) + "./nonexisting/relative/path/to/file.txt", + "file://./nonexisting/relative/path/to/file.txt", + false, + }, + { // Relative path (existing) + "./test-fixtures/SomeDir/myfile.txt", + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // Absolute uri (existing and with `/` prefix) + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath), + false, + }, + { // Absolute uri (non-existing and with `/` prefix) + "file:///path/to/non-existing/file.txt", + "file:///path/to/non-existing/file.txt", + false, + }, + { // Absolute uri (non-existing and missing `/` prefix) + "file://path/to/non-existing/file.txt", + "file://path/to/non-existing/file.txt", + false, + }, + { // Absolute uri and volume (non-existing and with `/` prefix) + "file:///T:/path/to/non-existing/file.txt", + "file:///T:/path/to/non-existing/file.txt", + false, + }, + { // Absolute uri and volume (non-existing and missing `/` prefix) + "file://T:/path/to/non-existing/file.txt", + "file://T:/path/to/non-existing/file.txt", false, }, } - // create absolute-pathed tempfile to play with - err := os.Mkdir("C:\\Temp\\SomeDir", 0755) - if err != nil { - t.Fatalf("err creating test dir: %s", err) - } - fi, err := os.Create("C:\\Temp\\SomeDir\\myfile.txt") - if err != nil { - t.Fatalf("err creating test file: %s", err) - } - fi.Close() - defer os.Remove("C:\\Temp\\SomeDir\\myfile.txt") - defer os.Remove("C:\\Temp\\SomeDir") - // Run through test cases to make sure they all parse correctly - for _, tc := range dirCases { + for idx, tc := range dirCases { u, err := DownloadableURL(tc.InputString) if (err != nil) != tc.ErrExpected { - t.Fatalf("Test Case failed: Expected err = %#v, err = %#v, input = %s", - tc.ErrExpected, err, tc.InputString) + t.Fatalf("Test Case %d failed: Expected err = %#v, err = %#v, input = %s", + idx, tc.ErrExpected, err, tc.InputString) } if u != tc.OutputURL { - t.Fatalf("Test Case failed: Expected %s but received %s from input %s", - tc.OutputURL, u, tc.InputString) + t.Fatalf("Test Case %d failed: Expected %s but received %s from input %s", + idx, tc.OutputURL, u, tc.InputString) } } } @@ -154,10 +224,12 @@ func TestDownloadableURL_FilePaths(t *testing.T) { } tfPath = filepath.Clean(tfPath) - filePrefix := "file://" + + // If we're running windows, then absolute URIs are `/`-prefixed. + platformPrefix := "" if runtime.GOOS == "windows" { - filePrefix += "/" + platformPrefix = "/" } // Relative filepath. We run this test in a func so that @@ -180,8 +252,9 @@ func TestDownloadableURL_FilePaths(t *testing.T) { t.Fatalf("err: %s", err) } - expected := fmt.Sprintf("%s%s", + expected := fmt.Sprintf("%s%s%s", filePrefix, + platformPrefix, strings.Replace(tfPath, `\`, `/`, -1)) if u != expected { t.Fatalf("unexpected: %#v != %#v", u, expected) @@ -189,21 +262,22 @@ func TestDownloadableURL_FilePaths(t *testing.T) { }() // Test some cases with and without a schema prefix - for _, prefix := range []string{"", filePrefix} { + for _, prefix := range []string{"", filePrefix + platformPrefix} { // Nonexistent file _, err = DownloadableURL(prefix + "i/dont/exist") if err != nil { t.Fatalf("err: %s", err) } - // Good file + // Good file (absolute) u, err := DownloadableURL(prefix + tfPath) if err != nil { t.Fatalf("err: %s", err) } - expected := fmt.Sprintf("%s%s", + expected := fmt.Sprintf("%s%s%s", filePrefix, + platformPrefix, strings.Replace(tfPath, `\`, `/`, -1)) if u != expected { t.Fatalf("unexpected: %s != %s", u, expected) @@ -211,39 +285,28 @@ func TestDownloadableURL_FilePaths(t *testing.T) { } } -func test_FileExistsLocally(t *testing.T) { - if runtime.GOOS == "windows" { - dirCases := []struct { - Input string - Output bool - }{ - // file exists locally - {"file:///C:/Temp/SomeDir/myfile.txt", true}, - // file is not supposed to exist locally - {"https://myfile.iso", true}, - // file does not exist locally - {"file:///C/i/dont/exist", false}, - } - // create absolute-pathed tempfile to play with - err := os.Mkdir("C:\\Temp\\SomeDir", 0755) - if err != nil { - t.Fatalf("err creating test dir: %s", err) - } - fi, err := os.Create("C:\\Temp\\SomeDir\\myfile.txt") - if err != nil { - t.Fatalf("err creating test file: %s", err) - } - fi.Close() - defer os.Remove("C:\\Temp\\SomeDir\\myfile.txt") - defer os.Remove("C:\\Temp\\SomeDir") +func TestFileExistsLocally(t *testing.T) { + portablepath := GetPortablePathToTestFixtures(t) - // Run through test cases to make sure they all parse correctly - for _, tc := range dirCases { - fileOK := FileExistsLocally(tc.Input) - if !fileOK { - t.Fatalf("Test Case failed: Expected %#v, received = %#v, input = %s", - tc.Output, fileOK, tc.Input) - } + dirCases := []struct { + Input string + Output bool + }{ + // file exists locally + {fmt.Sprintf("file://%s/SomeDir/myfile.txt", portablepath), true}, + // remote protocols short-circuit and are considered to exist locally + {"https://myfile.iso", true}, + // non-existent protocols do not exist and hence fail + {"nonexistent-protocol://myfile.iso", false}, + // file does not exist locally + {"file:///C/i/dont/exist", false}, + } + // Run through test cases to make sure they all parse correctly + for _, tc := range dirCases { + fileOK := FileExistsLocally(tc.Input) + if fileOK != tc.Output { + t.Fatalf("Test Case failed: Expected %#v, received = %#v, input = %s", + tc.Output, fileOK, tc.Input) } } } diff --git a/common/download.go b/common/download.go index ada4c9756..6d3d2a685 100644 --- a/common/download.go +++ b/common/download.go @@ -10,12 +10,19 @@ import ( "errors" "fmt" "hash" - "io" "log" - "net/http" "net/url" "os" + "path" "runtime" + "strings" +) + +// imports related to each Downloader implementation +import ( + "io" + "net/http" + "path/filepath" ) // DownloadConfig is the configuration given to instantiate a new @@ -75,23 +82,38 @@ func HashForType(t string) hash.Hash { // NewDownloadClient returns a new DownloadClient for the given // configuration. func NewDownloadClient(c *DownloadConfig) *DownloadClient { + const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */ + + // Create downloader map if it hasn't been specified already. if c.DownloaderMap == nil { c.DownloaderMap = map[string]Downloader{ + "file": &FileDownloader{bufferSize: nil}, "http": &HTTPDownloader{userAgent: c.UserAgent}, "https": &HTTPDownloader{userAgent: c.UserAgent}, + "smb": &SMBDownloader{bufferSize: nil}, } } - return &DownloadClient{config: c} } -// A downloader is responsible for actually taking a remote URL and -// downloading it. +// A downloader implements the ability to transfer, cancel, or resume a file. type Downloader interface { + Resume() Cancel() + Progress() uint64 + Total() uint64 +} + +// A LocalDownloader is responsible for converting a uri to a local path +// that the platform can open directly. +type LocalDownloader interface { + toPath(string, url.URL) (string, error) +} + +// A RemoteDownloader is responsible for actually taking a remote URL and +// downloading it. +type RemoteDownloader interface { Download(*os.File, *url.URL) error - Progress() uint - Total() uint } func (d *DownloadClient) Cancel() { @@ -105,62 +127,64 @@ func (d *DownloadClient) Get() (string, error) { return d.config.TargetPath, nil } + /* parse the configuration url into a net/url object */ u, err := url.Parse(d.config.Url) if err != nil { return "", err } - log.Printf("Parsed URL: %#v", u) - // Files when we don't copy the file are special cased. - var f *os.File + /* use the current working directory as the base for relative uri's */ + cwd, err := os.Getwd() + if err != nil { + return "", err + } + + // Determine which is the correct downloader to use var finalPath string - sourcePath := "" - if u.Scheme == "file" && !d.config.CopyFile { - // This is special case for relative path in this case user specify - // file:../ and after parse destination goes to Opaque - if u.Path != "" { - // If url.Path is set just use this - finalPath = u.Path - } else if u.Opaque != "" { - // otherwise try url.Opaque - finalPath = u.Opaque - } - // This is a special case where we use a source file that already exists - // locally and we don't make a copy. Normally we would copy or download. - log.Printf("[DEBUG] Using local file: %s", finalPath) - // Remove forward slash on absolute Windows file URLs before processing - if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' { - finalPath = finalPath[1:] - } + var ok bool + d.downloader, ok = d.config.DownloaderMap[u.Scheme] + if !ok { + return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme) + } - // Keep track of the source so we can make sure not to delete this later - sourcePath = finalPath - if _, err = os.Stat(finalPath); err != nil { - return "", err - } - } else { + remote, ok := d.downloader.(RemoteDownloader) + if !ok { + return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, d.downloader) + } + + local, ok := d.downloader.(LocalDownloader) + if !ok && !d.config.CopyFile { + d.config.CopyFile = true + } + + // If we're copying the file, then just use the actual downloader + if d.config.CopyFile { + var f *os.File finalPath = d.config.TargetPath - var ok bool - d.downloader, ok = d.config.DownloaderMap[u.Scheme] - if !ok { - return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme) - } - - // Otherwise, download using the downloader. f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) if err != nil { return "", err } log.Printf("[DEBUG] Downloading: %s", u.String()) - err = d.downloader.Download(f, u) + err = remote.Download(f, u) f.Close() if err != nil { return "", err } + + // Otherwise if our Downloader is a LocalDownloader we can just use the + // path after transforming it. + } else { + finalPath, err = local.toPath(cwd, *u) + if err != nil { + return "", err + } + + log.Printf("[DEBUG] Using local file: %s", finalPath) } if d.config.Hash != nil { @@ -168,7 +192,7 @@ func (d *DownloadClient) Get() (string, error) { verify, err = d.VerifyChecksum(finalPath) if err == nil && !verify { // Only delete the file if we made a copy or downloaded it - if sourcePath != finalPath { + if d.config.CopyFile { os.Remove(finalPath) } @@ -181,7 +205,6 @@ func (d *DownloadClient) Get() (string, error) { return finalPath, err } -// PercentProgress returns the download progress as a percentage. func (d *DownloadClient) PercentProgress() int { if d.downloader == nil { return -1 @@ -212,17 +235,21 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { // HTTPDownloader is an implementation of Downloader that downloads // files over HTTP. type HTTPDownloader struct { - progress uint - total uint + current uint64 + total uint64 userAgent string } -func (*HTTPDownloader) Cancel() { +func (d *HTTPDownloader) Cancel() { + // TODO(mitchellh): Implement +} + +func (d *HTTPDownloader) Resume() { // TODO(mitchellh): Implement } func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { - log.Printf("Starting download: %s", src.String()) + log.Printf("Starting download over HTTP: %s", src.String()) // Seek to the beginning by default if _, err := dst.Seek(0, 0); err != nil { @@ -230,7 +257,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { } // Reset our progress - d.progress = 0 + d.current = 0 // Make the request. We first make a HEAD request so we can check // if the server supports range queries. If the server/URL doesn't @@ -258,7 +285,8 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { if fi, err := dst.Stat(); err == nil { if _, err = dst.Seek(0, os.SEEK_END); err == nil { req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) - d.progress = uint(fi.Size()) + + d.current = uint64(fi.Size()) } } } @@ -272,7 +300,8 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return err } - d.total = d.progress + uint(resp.ContentLength) + d.total = d.current + uint64(resp.ContentLength) + var buffer [4096]byte for { n, err := resp.Body.Read(buffer[:]) @@ -280,7 +309,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return err } - d.progress += uint(n) + d.current += uint64(n) if _, werr := dst.Write(buffer[:n]); werr != nil { return werr @@ -290,14 +319,253 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { break } } - return nil } -func (d *HTTPDownloader) Progress() uint { - return d.progress +func (d *HTTPDownloader) Progress() uint64 { + return d.current } -func (d *HTTPDownloader) Total() uint { +func (d *HTTPDownloader) Total() uint64 { return d.total } + +// FileDownloader is an implementation of Downloader that downloads +// files using the regular filesystem. +type FileDownloader struct { + bufferSize *uint + + active bool + current uint64 + total uint64 +} + +func (d *FileDownloader) Progress() uint64 { + return d.current +} + +func (d *FileDownloader) Total() uint64 { + return d.total +} + +func (d *FileDownloader) Cancel() { + d.active = false +} + +func (d *FileDownloader) Resume() { + // TODO: Implement +} + +func (d *FileDownloader) toPath(base string, uri url.URL) (string, error) { + var result string + + // absolute path -- file://c:/absolute/path -> c:/absolute/path + if strings.HasSuffix(uri.Host, ":") { + result = path.Join(uri.Host, uri.Path) + + // semi-absolute path (current drive letter) + // -- file:///absolute/path -> drive:/absolute/path + } else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") { + apath := uri.Path + components := strings.Split(apath, "/") + volume := filepath.VolumeName(base) + + // semi-absolute absolute path (includes volume letter) + // -- file://drive:/path -> drive:/absolute/path + if len(components) > 1 && strings.HasSuffix(components[1], ":") { + volume = components[1] + apath = path.Join(components[2:]...) + } + + result = path.Join(volume, apath) + + // relative path -- file://./relative/path -> ./relative/path + } else if uri.Host == "." { + result = path.Join(base, uri.Path) + + // relative path -- file://relative/path -> ./relative/path + } else { + result = path.Join(base, uri.Host, uri.Path) + } + return filepath.ToSlash(result), nil +} + +func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { + d.active = false + + /* check the uri's scheme to make sure it matches */ + if src == nil || src.Scheme != "file" { + return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme) + } + uri := src + + /* use the current working directory as the base for relative uri's */ + cwd, err := os.Getwd() + if err != nil { + return err + } + + /* determine which uri format is being used and convert to a real path */ + realpath, err := d.toPath(cwd, *uri) + if err != nil { + return err + } + + /* download the file using the operating system's facilities */ + d.current = 0 + d.active = true + + f, err := os.Open(realpath) + if err != nil { + return err + } + defer f.Close() + + // get the file size + fi, err := f.Stat() + if err != nil { + return err + } + d.total = uint64(fi.Size()) + + // no bufferSize specified, so copy synchronously. + if d.bufferSize == nil { + var n int64 + n, err = io.Copy(dst, f) + d.active = false + + d.current += uint64(n) + + // use a goro in case someone else wants to enable cancel/resume + } else { + errch := make(chan error) + go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) { + for d.active { + n, err := io.CopyN(w, r, int64(*d.bufferSize)) + if err != nil { + break + } + + d.current += uint64(n) + } + d.active = false + e <- err + }(d, f, dst, errch) + + // ...and we spin until it's done + err = <-errch + } + f.Close() + return err +} + +// SMBDownloader is an implementation of Downloader that downloads +// files using the "\\" path format on Windows +type SMBDownloader struct { + bufferSize *uint + + active bool + current uint64 + total uint64 +} + +func (d *SMBDownloader) Progress() uint64 { + return d.current +} + +func (d *SMBDownloader) Total() uint64 { + return d.total +} + +func (d *SMBDownloader) Cancel() { + d.active = false +} + +func (d *SMBDownloader) Resume() { + // TODO: Implement +} + +func (d *SMBDownloader) toPath(base string, uri url.URL) (string, error) { + const UNCPrefix = string(os.PathSeparator) + string(os.PathSeparator) + + if runtime.GOOS != "windows" { + return "", fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS) + } + + return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil +} + +func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { + + /* first we warn the world if we're not running windows */ + if runtime.GOOS != "windows" { + return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS) + } + + d.active = false + + /* convert the uri using the net/url module to a UNC path */ + if src == nil || src.Scheme != "smb" { + return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme) + } + uri := src + + /* use the current working directory as the base for relative uri's */ + cwd, err := os.Getwd() + if err != nil { + return err + } + + /* convert uri to an smb-path */ + realpath, err := d.toPath(cwd, *uri) + if err != nil { + return err + } + + /* Open up the "\\"-prefixed path using the Windows filesystem */ + d.current = 0 + d.active = true + + f, err := os.Open(realpath) + if err != nil { + return err + } + defer f.Close() + + // get the file size (at the risk of performance) + fi, err := f.Stat() + if err != nil { + return err + } + d.total = uint64(fi.Size()) + + // no bufferSize specified, so copy synchronously. + if d.bufferSize == nil { + var n int64 + n, err = io.Copy(dst, f) + d.active = false + + d.current += uint64(n) + + // use a goro in case someone else wants to enable cancel/resume + } else { + errch := make(chan error) + go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) { + for d.active { + n, err := io.CopyN(w, r, int64(*d.bufferSize)) + if err != nil { + break + } + + d.current += uint64(n) + } + d.active = false + e <- err + }(d, f, dst, errch) + + // ...and as usual we spin until it's done + err = <-errch + } + f.Close() + return err +} diff --git a/common/download_test.go b/common/download_test.go index 497a05649..beac9cbab 100644 --- a/common/download_test.go +++ b/common/download_test.go @@ -8,7 +8,9 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" "runtime" + "strings" "testing" ) @@ -56,6 +58,7 @@ func TestDownloadClient_basic(t *testing.T) { client := NewDownloadClient(&DownloadConfig{ Url: ts.URL + "/basic.txt", TargetPath: tf.Name(), + CopyFile: true, }) path, err := client.Get() @@ -91,6 +94,7 @@ func TestDownloadClient_checksumBad(t *testing.T) { TargetPath: tf.Name(), Hash: HashForType("md5"), Checksum: checksum, + CopyFile: true, }) if _, err := client.Get(); err == nil { t.Fatal("should error") @@ -115,6 +119,7 @@ func TestDownloadClient_checksumGood(t *testing.T) { TargetPath: tf.Name(), Hash: HashForType("md5"), Checksum: checksum, + CopyFile: true, }) path, err := client.Get() if err != nil { @@ -145,6 +150,7 @@ func TestDownloadClient_checksumNoDownload(t *testing.T) { TargetPath: "./test-fixtures/root/another.txt", Hash: HashForType("md5"), Checksum: checksum, + CopyFile: true, }) path, err := client.Get() if err != nil { @@ -183,6 +189,7 @@ func TestDownloadClient_resume(t *testing.T) { client := NewDownloadClient(&DownloadConfig{ Url: ts.URL, TargetPath: tf.Name(), + CopyFile: true, }) path, err := client.Get() if err != nil { @@ -240,6 +247,7 @@ func TestDownloadClient_usesDefaultUserAgent(t *testing.T) { config := &DownloadConfig{ Url: server.URL, TargetPath: tf.Name(), + CopyFile: true, } client := NewDownloadClient(config) @@ -271,6 +279,7 @@ func TestDownloadClient_setsUserAgent(t *testing.T) { Url: server.URL, TargetPath: tf.Name(), UserAgent: "fancy user agent", + CopyFile: true, } client := NewDownloadClient(config) @@ -351,6 +360,7 @@ func TestDownloadFileUrl(t *testing.T) { if err != nil { t.Fatalf("Unable to detect working directory: %s", err) } + cwd = filepath.ToSlash(cwd) // source_path is a file path and source is a network path sourcePath := fmt.Sprintf("%s/test-fixtures/fileurl/%s", cwd, "cake") @@ -376,11 +386,116 @@ func TestDownloadFileUrl(t *testing.T) { // Verify that we fail to match the checksum _, err = client.Get() if err.Error() != "checksums didn't match expected: 6e6f7065" { - t.Fatalf("Unexpected failure; expected checksum not to match") + t.Fatalf("Unexpected failure; expected checksum not to match. Error was \"%v\"", err) } if _, err = os.Stat(sourcePath); err != nil { t.Errorf("Could not stat source file: %s", sourcePath) } - +} + +// SimulateFileUriDownload is a simple utility function that converts a uri +// into a testable file path whilst ignoring a correct checksum match, stripping +// UNC path info, and then calling stat to ensure the correct file exists. +// (used by TestFileUriTransforms) +func SimulateFileUriDownload(t *testing.T, uri string) (string, error) { + // source_path is a file path and source is a network path + source := fmt.Sprintf(uri) + t.Logf("Trying to download %s", source) + + config := &DownloadConfig{ + Url: source, + // This should be wrong. We want to make sure we don't delete + Checksum: []byte("nope"), + Hash: HashForType("sha256"), + CopyFile: false, + } + + // go go go + client := NewDownloadClient(config) + path, err := client.Get() + + // ignore any non-important checksum errors if it's not a unc path + if !strings.HasPrefix(path, "\\\\") && err.Error() != "checksums didn't match expected: 6e6f7065" { + t.Fatalf("Unexpected failure; expected checksum not to match") + } + + // if it's a unc path, then remove the host and share name so we don't have + // to force the user to enable ADMIN$ and Windows File Sharing + if strings.HasPrefix(path, "\\\\") { + res := strings.SplitN(path, "/", 3) + path = "/" + res[2] + } + + if _, err = os.Stat(path); err != nil { + t.Errorf("Could not stat source file: %s", path) + } + return path, err +} + +// TestFileUriTransforms tests the case where we use a local file uri +// for iso_url. There's a few different formats that a file uri can exist as +// and so we try to test the most useful and common ones. +func TestFileUriTransforms(t *testing.T) { + const testpath = /* have your */ "test-fixtures/fileurl/cake" /* and eat it too */ + const host = "localhost" + + var cwd string + var volume string + var share string + + cwd, err := os.Getwd() + if err != nil { + t.Fatalf("Unable to detect working directory: %s", err) + return + } + cwd = filepath.ToSlash(cwd) + volume = filepath.VolumeName(cwd) + share = volume + + // if a volume was found (on windows), replace the ':' from + // C: to C$ to convert it into a hidden windows share. + if len(share) > 1 && share[len(share)-1] == ':' { + share = share[:len(share)-1] + "$" + } + cwd = cwd[len(volume):] + + t.Logf("TestFileUriTransforms : Running with cwd : '%s'", cwd) + t.Logf("TestFileUriTransforms : Running with volume : '%s'", volume) + + // ./relative/path -> ./relative/path + // /absolute/path -> /absolute/path + // c:/windows/absolute -> c:/windows/absolute + testcases := []string{ + "./%s", + cwd + "/%s", + volume + cwd + "/%s", + } + + // all regular slashed testcases + for _, testcase := range testcases { + uri := "file://" + fmt.Sprintf(testcase, testpath) + t.Logf("TestFileUriTransforms : Trying Uri '%s'", uri) + res, err := SimulateFileUriDownload(t, uri) + if err != nil { + t.Errorf("Unable to transform uri '%s' into a path : %v", uri, err) + } + t.Logf("TestFileUriTransforms : Result Path '%s'", res) + } + + // smb protocol depends on platform support which currently + // only exists on windows. + if runtime.GOOS == "windows" { + // ...and finally the oddball windows native path + // smb://host/sharename/file -> \\host\sharename\file + testcase := host + "/" + share + "/" + cwd[1:] + "/%s" + uri := "smb://" + fmt.Sprintf(testcase, testpath) + t.Logf("TestFileUriTransforms : Trying Uri '%s'", uri) + res, err := SimulateFileUriDownload(t, uri) + if err != nil { + t.Errorf("Unable to transform uri '%s' into a path", uri) + return + } + t.Logf("TestFileUriTransforms : Result Path '%s'", res) + } } diff --git a/common/iso_config.go b/common/iso_config.go index 5216dc458..e5a1e4e90 100644 --- a/common/iso_config.go +++ b/common/iso_config.go @@ -111,7 +111,7 @@ func (c *ISOConfig) Prepare(ctx *interpolate.Context) (warnings []string, errs [ c.ISOChecksum = strings.ToLower(c.ISOChecksum) for i, url := range c.ISOUrls { - url, err := DownloadableURL(url) + url, err := ValidatedURL(url) if err != nil { errs = append( errs, fmt.Errorf("Failed to parse iso_url %d: %s", i+1, err)) diff --git a/common/test-fixtures/SomeDir/myfile.txt b/common/test-fixtures/SomeDir/myfile.txt new file mode 100644 index 000000000..e69de29bb