Merge pull request #2906 from arizvisa/GH-2377

Improved support for downloading and validating a uri containing a Windows UNC path or a relative file:// scheme
This commit is contained in:
SwampDragons 2018-02-05 09:53:47 -08:00 committed by GitHub
commit 7d5d62d748
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 735 additions and 240 deletions

View File

@ -122,7 +122,7 @@ func (s *StepDownloadGuestAdditions) Run(ctx context.Context, state multistep.St
} }
// Convert the file/url to an actual URL for step_download to process. // Convert the file/url to an actual URL for step_download to process.
url, err = common.DownloadableURL(url) url, err = common.ValidatedURL(url)
if err != nil { if err != nil {
err := fmt.Errorf("Error preparing guest additions url: %s", err) err := fmt.Errorf("Error preparing guest additions url: %s", err)
state.Put("error", err) state.Put("error", err)

View File

@ -97,7 +97,7 @@ func NewConfig(raws ...interface{}) (*Config, []string, error) {
if c.SourcePath == "" { if c.SourcePath == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is required")) errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is required"))
} else { } else {
c.SourcePath, err = common.DownloadableURL(c.SourcePath) c.SourcePath, err = common.ValidatedURL(c.SourcePath)
if err != nil { if err != nil {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is invalid: %s", err)) errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is invalid: %s", err))
} }

View File

@ -84,7 +84,7 @@ func TestNewConfig_sourcePath(t *testing.T) {
t.Fatalf("bad: %#v", warns) t.Fatalf("bad: %#v", warns)
} }
if err == nil { if err == nil {
t.Fatal("should error") t.Fatalf("should error")
} }
// Good // Good

View File

@ -5,6 +5,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"time" "time"
@ -43,89 +44,126 @@ func ChooseString(vals ...string) string {
return "" return ""
} }
// DownloadableURL processes a URL that may also be a file path and returns // SupportedProtocol verifies that the url passed is actually supported or not
// a completely valid URL. For example, the original URL might be "local/file.iso" // This will also validate that the protocol is one that's actually implemented.
// which isn't a valid URL. DownloadableURL will return "file:///local/file.iso" func SupportedProtocol(u *url.URL) bool {
func DownloadableURL(original string) (string, error) { // url.Parse shouldn't return nil except on error....but it can.
if runtime.GOOS == "windows" { if u == nil {
// If the distance to the first ":" is just one character, assume return false
// we're dealing with a drive letter and thus a file path. }
// prepend with "file:///"" now so that url.Parse won't accidentally
// parse the drive letter into the url scheme. // build a dummy NewDownloadClient since this is the only place that valid
// See https://blogs.msdn.microsoft.com/ie/2006/12/06/file-uris-in-windows/ // protocols are actually exposed.
// for more info about valid windows URIs cli := NewDownloadClient(&DownloadConfig{})
idx := strings.Index(original, ":")
if idx == 1 { // Iterate through each downloader to see if a protocol was found.
original = "file:///" + original ok := false
for scheme := range cli.config.DownloaderMap {
if strings.ToLower(u.Scheme) == strings.ToLower(scheme) {
ok = true
} }
} }
return ok
}
// DownloadableURL processes a URL that may also be a file path and returns
// a completely valid URL representing the requested file. For example,
// the original URL might be "local/file.iso" which isn't a valid URL,
// and so DownloadableURL will return "file://local/file.iso"
// No other transformations are done to the path.
func DownloadableURL(original string) (string, error) {
var absPrefix, result string
absPrefix = ""
if runtime.GOOS == "windows" {
absPrefix = "/"
}
// Check that the user specified a UNC path, and promote it to an smb:// uri.
if strings.HasPrefix(original, "\\\\") && len(original) > 2 && original[2] != '?' {
result = filepath.ToSlash(original[2:])
return fmt.Sprintf("smb://%s", result), nil
}
// Fix the url if it's using bad characters commonly mistaken with a path.
original = filepath.ToSlash(original)
// Check to see that this is a parseable URL with a scheme and a host.
// If so, then just pass it through.
if u, err := url.Parse(original); err == nil && u.Scheme != "" && u.Host != "" {
return original, nil
}
// If it's a file scheme, then convert it back to a regular path so the next
// case which forces it to an absolute path, will correct it.
if u, err := url.Parse(original); err == nil && strings.ToLower(u.Scheme) == "file" {
original = u.Path
}
// If we're on Windows and we start with a slash, then this absolute path
// is wrong. Fix it up, so the next case can figure out the absolute path.
if rpath := strings.SplitN(original, "/", 2); rpath[0] == "" && runtime.GOOS == "windows" {
result = rpath[1]
} else {
result = original
}
// Since we should be some kind of path (relative or absolute), check
// that the file exists, then make it an absolute path so we can return an
// absolute uri.
if _, err := os.Stat(result); err == nil {
result, err = filepath.Abs(filepath.FromSlash(result))
if err != nil {
return "", err
}
result, err = filepath.EvalSymlinks(result)
if err != nil {
return "", err
}
result = filepath.Clean(result)
return fmt.Sprintf("file://%s%s", absPrefix, filepath.ToSlash(result)), nil
}
// Otherwise, check if it was originally an absolute path, and fix it if so.
if strings.HasPrefix(original, "/") {
return fmt.Sprintf("file://%s%s", absPrefix, result), nil
}
// Anything left should be a non-existent relative path. So fix it up here.
result = filepath.ToSlash(filepath.Clean(result))
return fmt.Sprintf("file://./%s", result), nil
}
// Force the parameter into a url. This will transform the parameter into
// a proper url, removing slashes, adding the proper prefix, etc.
func ValidatedURL(original string) (string, error) {
// See if the user failed to give a url
if ok, _ := regexp.MatchString("(?m)^[^[:punct:]]+://", original); !ok {
// So since no magic was found, this must be a path.
result, err := DownloadableURL(original)
if err == nil {
return ValidatedURL(result)
}
return "", err
}
// Verify that the url is parseable...just in case.
u, err := url.Parse(original) u, err := url.Parse(original)
if err != nil { if err != nil {
return "", err return "", err
} }
if u.Scheme == "" { // We should now have a url, so verify that it's a protocol we support.
u.Scheme = "file" if !SupportedProtocol(u) {
return "", fmt.Errorf("Unsupported protocol scheme! (%#v)", u)
} }
if u.Scheme == "file" { // We should now have a properly formatted and supported url
// Windows file handling is all sorts of tricky...
if runtime.GOOS == "windows" {
// If the path is using Windows-style slashes, URL parses
// it into the host field.
if u.Path == "" && strings.Contains(u.Host, `\`) {
u.Path = u.Host
u.Host = ""
}
}
// Only do the filepath transformations if the file appears
// to actually exist.
if _, err := os.Stat(u.Path); err == nil {
u.Path, err = filepath.Abs(u.Path)
if err != nil {
return "", err
}
u.Path, err = filepath.EvalSymlinks(u.Path)
if err != nil {
return "", err
}
u.Path = filepath.Clean(u.Path)
}
if runtime.GOOS == "windows" {
// Also replace all backslashes with forwardslashes since Windows
// users are likely to do this but the URL should actually only
// contain forward slashes.
u.Path = strings.Replace(u.Path, `\`, `/`, -1)
// prepend absolute windows paths with "/" so that when we
// compose u.String() below the outcome will be correct
// file:///c/blah syntax; otherwise u.String() will only add
// file:// which is not technically a correct windows URI
if filepath.IsAbs(u.Path) && !strings.HasPrefix(u.Path, "/") {
u.Path = "/" + u.Path
}
}
}
// Make sure it is lowercased
u.Scheme = strings.ToLower(u.Scheme)
// Verify that the scheme is something we support in our common downloader.
supported := []string{"file", "http", "https"}
found := false
for _, s := range supported {
if u.Scheme == s {
found = true
break
}
}
if !found {
return "", fmt.Errorf("Unsupported URL scheme: %s", u.Scheme)
}
return u.String(), nil return u.String(), nil
} }
@ -144,27 +182,38 @@ func DownloadableURL(original string) (string, error) {
func FileExistsLocally(original string) bool { func FileExistsLocally(original string) bool {
// original should be something like file://C:/my/path.iso // original should be something like file://C:/my/path.iso
u, _ := url.Parse(original)
fileURL, _ := url.Parse(original) // First create a dummy downloader so we can figure out which
fileExists := false // protocol to use.
cli := NewDownloadClient(&DownloadConfig{})
if fileURL.Scheme == "file" { d, ok := cli.config.DownloaderMap[u.Scheme]
// on windows, correct URI is file:///c:/blah/blah.iso. if !ok {
// url.Parse will pull out the scheme "file://" and leave the path as return false
// "/c:/blah/blah/iso". Here we remove this forward slash on absolute
// Windows file URLs before processing
// see https://blogs.msdn.microsoft.com/ie/2006/12/06/file-uris-in-windows/
// for more info about valid windows URIs
filePath := fileURL.Path
if runtime.GOOS == "windows" && len(filePath) > 0 && filePath[0] == '/' {
filePath = filePath[1:]
}
_, err := os.Stat(filePath)
if err != nil {
return fileExists
} else {
fileExists = true
}
} }
return fileExists
// Check to see that it's got a Local way of doing things.
local, ok := d.(LocalDownloader)
if !ok {
return true // XXX: Remote URLs short-circuit this logic.
}
// Figure out where we're at.
wd, err := os.Getwd()
if err != nil {
return false
}
// Now figure out the real path to the file.
realpath, err := local.toPath(wd, *u)
if err != nil {
return false
}
// Finally we can seek the truth via os.Stat.
_, err = os.Stat(realpath)
if err != nil {
return false
}
return true
} }

View File

@ -37,7 +37,28 @@ func TestChooseString(t *testing.T) {
} }
} }
func TestDownloadableURL(t *testing.T) { func TestValidatedURL(t *testing.T) {
// Invalid URL: has hex code in host
_, err := ValidatedURL("http://what%20.com")
if err == nil {
t.Fatalf("expected err : %s", err)
}
// Invalid: unsupported scheme
_, err = ValidatedURL("ftp://host.com/path")
if err == nil {
t.Fatalf("expected err : %s", err)
}
// Valid: http
u, err := ValidatedURL("HTTP://packer.io/path")
if err != nil {
t.Fatalf("err: %s", err)
}
if u != "http://packer.io/path" {
t.Fatalf("bad: %s", u)
}
cases := []struct { cases := []struct {
InputString string InputString string
@ -55,7 +76,7 @@ func TestDownloadableURL(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
u, err := DownloadableURL(tc.InputString) u, err := ValidatedURL(tc.InputString)
if u != tc.OutputURL { if u != tc.OutputURL {
t.Fatal(fmt.Sprintf("Error with URL %s: got %s but expected %s", t.Fatal(fmt.Sprintf("Error with URL %s: got %s but expected %s",
tc.InputString, tc.OutputURL, u)) tc.InputString, tc.OutputURL, u))
@ -63,78 +84,127 @@ func TestDownloadableURL(t *testing.T) {
if (err != nil) != tc.ErrExpected { if (err != nil) != tc.ErrExpected {
if tc.ErrExpected == true { if tc.ErrExpected == true {
t.Fatal(fmt.Sprintf("Error with URL %s: we expected "+ t.Fatal(fmt.Sprintf("Error with URL %s: we expected "+
"DownloadableURL to return an error but didn't get one.", "ValidatedURL to return an error but didn't get one.",
tc.InputString)) tc.InputString))
} else { } else {
t.Fatal(fmt.Sprintf("Error with URL %s: we did not expect an "+ t.Fatal(fmt.Sprintf("Error with URL %s: we did not expect an "+
" error from DownloadableURL but we got: %s", " error from ValidatedURL but we got: %s",
tc.InputString, err)) tc.InputString, err))
} }
} }
} }
} }
func GetNativePathToTestFixtures(t *testing.T) string {
const path = "./test-fixtures"
res, err := filepath.Abs(path)
if err != nil {
t.Fatalf("err converting test-fixtures path into an absolute path : %s", err)
}
return res
}
func GetPortablePathToTestFixtures(t *testing.T) string {
res := GetNativePathToTestFixtures(t)
return filepath.ToSlash(res)
}
func TestDownloadableURL_WindowsFiles(t *testing.T) { func TestDownloadableURL_WindowsFiles(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
portablepath := GetPortablePathToTestFixtures(t)
nativepath := GetNativePathToTestFixtures(t)
dirCases := []struct { dirCases := []struct {
InputString string InputString string
OutputURL string OutputURL string
ErrExpected bool ErrExpected bool
}{ // TODO: add different directories }{ // TODO: add different directories
{ {
"C:\\Temp\\SomeDir\\myfile.txt", fmt.Sprintf("%s\\SomeDir\\myfile.txt", nativepath),
"file:///C:/Temp/SomeDir/myfile.txt", fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false, false,
}, },
{ // need windows drive { // without the drive makes this native path a relative file:// uri
"\\Temp\\SomeDir\\myfile.txt", "test-fixtures\\SomeDir\\myfile.txt",
"", fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
true,
},
{ // need windows drive
"/Temp/SomeDir/myfile.txt",
"",
true,
},
{ // UNC paths; why not?
"\\\\?\\c:\\Temp\\SomeDir\\myfile.txt",
"",
true,
},
{
"file:///C:\\Temp\\SomeDir\\myfile.txt",
"file:///c:/Temp/SomeDir/myfile.txt",
false, false,
}, },
{ { // without the drive makes this native path a relative file:// uri
"file:///c:/Temp/Somedir/myfile.txt", "test-fixtures/SomeDir/myfile.txt",
"file:///c:/Temp/SomeDir/myfile.txt", fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // UNC paths being promoted to smb:// uri scheme.
fmt.Sprintf("\\\\localhost\\C$\\%s\\SomeDir\\myfile.txt", nativepath),
fmt.Sprintf("smb://localhost/C$/%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // Absolute uri (incorrect slash type)
fmt.Sprintf("file:///%s\\SomeDir\\myfile.txt", nativepath),
fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // Absolute uri (existing and mis-spelled)
fmt.Sprintf("file:///%s/Somedir/myfile.txt", nativepath),
fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // Absolute path (non-existing)
"\\absolute\\path\\to\\non-existing\\file.txt",
"file:///absolute/path/to/non-existing/file.txt",
false,
},
{ // Absolute paths (existing)
fmt.Sprintf("%s/SomeDir/myfile.txt", nativepath),
fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // Relative path (non-existing)
"./nonexisting/relative/path/to/file.txt",
"file://./nonexisting/relative/path/to/file.txt",
false,
},
{ // Relative path (existing)
"./test-fixtures/SomeDir/myfile.txt",
fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // Absolute uri (existing and with `/` prefix)
fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
fmt.Sprintf("file:///%s/SomeDir/myfile.txt", portablepath),
false,
},
{ // Absolute uri (non-existing and with `/` prefix)
"file:///path/to/non-existing/file.txt",
"file:///path/to/non-existing/file.txt",
false,
},
{ // Absolute uri (non-existing and missing `/` prefix)
"file://path/to/non-existing/file.txt",
"file://path/to/non-existing/file.txt",
false,
},
{ // Absolute uri and volume (non-existing and with `/` prefix)
"file:///T:/path/to/non-existing/file.txt",
"file:///T:/path/to/non-existing/file.txt",
false,
},
{ // Absolute uri and volume (non-existing and missing `/` prefix)
"file://T:/path/to/non-existing/file.txt",
"file://T:/path/to/non-existing/file.txt",
false, false,
}, },
} }
// create absolute-pathed tempfile to play with
err := os.Mkdir("C:\\Temp\\SomeDir", 0755)
if err != nil {
t.Fatalf("err creating test dir: %s", err)
}
fi, err := os.Create("C:\\Temp\\SomeDir\\myfile.txt")
if err != nil {
t.Fatalf("err creating test file: %s", err)
}
fi.Close()
defer os.Remove("C:\\Temp\\SomeDir\\myfile.txt")
defer os.Remove("C:\\Temp\\SomeDir")
// Run through test cases to make sure they all parse correctly // Run through test cases to make sure they all parse correctly
for _, tc := range dirCases { for idx, tc := range dirCases {
u, err := DownloadableURL(tc.InputString) u, err := DownloadableURL(tc.InputString)
if (err != nil) != tc.ErrExpected { if (err != nil) != tc.ErrExpected {
t.Fatalf("Test Case failed: Expected err = %#v, err = %#v, input = %s", t.Fatalf("Test Case %d failed: Expected err = %#v, err = %#v, input = %s",
tc.ErrExpected, err, tc.InputString) idx, tc.ErrExpected, err, tc.InputString)
} }
if u != tc.OutputURL { if u != tc.OutputURL {
t.Fatalf("Test Case failed: Expected %s but received %s from input %s", t.Fatalf("Test Case %d failed: Expected %s but received %s from input %s",
tc.OutputURL, u, tc.InputString) idx, tc.OutputURL, u, tc.InputString)
} }
} }
} }
@ -154,10 +224,12 @@ func TestDownloadableURL_FilePaths(t *testing.T) {
} }
tfPath = filepath.Clean(tfPath) tfPath = filepath.Clean(tfPath)
filePrefix := "file://" filePrefix := "file://"
// If we're running windows, then absolute URIs are `/`-prefixed.
platformPrefix := ""
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
filePrefix += "/" platformPrefix = "/"
} }
// Relative filepath. We run this test in a func so that // Relative filepath. We run this test in a func so that
@ -180,8 +252,9 @@ func TestDownloadableURL_FilePaths(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
expected := fmt.Sprintf("%s%s", expected := fmt.Sprintf("%s%s%s",
filePrefix, filePrefix,
platformPrefix,
strings.Replace(tfPath, `\`, `/`, -1)) strings.Replace(tfPath, `\`, `/`, -1))
if u != expected { if u != expected {
t.Fatalf("unexpected: %#v != %#v", u, expected) t.Fatalf("unexpected: %#v != %#v", u, expected)
@ -189,21 +262,22 @@ func TestDownloadableURL_FilePaths(t *testing.T) {
}() }()
// Test some cases with and without a schema prefix // Test some cases with and without a schema prefix
for _, prefix := range []string{"", filePrefix} { for _, prefix := range []string{"", filePrefix + platformPrefix} {
// Nonexistent file // Nonexistent file
_, err = DownloadableURL(prefix + "i/dont/exist") _, err = DownloadableURL(prefix + "i/dont/exist")
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
// Good file // Good file (absolute)
u, err := DownloadableURL(prefix + tfPath) u, err := DownloadableURL(prefix + tfPath)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
expected := fmt.Sprintf("%s%s", expected := fmt.Sprintf("%s%s%s",
filePrefix, filePrefix,
platformPrefix,
strings.Replace(tfPath, `\`, `/`, -1)) strings.Replace(tfPath, `\`, `/`, -1))
if u != expected { if u != expected {
t.Fatalf("unexpected: %s != %s", u, expected) t.Fatalf("unexpected: %s != %s", u, expected)
@ -211,39 +285,28 @@ func TestDownloadableURL_FilePaths(t *testing.T) {
} }
} }
func test_FileExistsLocally(t *testing.T) { func TestFileExistsLocally(t *testing.T) {
if runtime.GOOS == "windows" { portablepath := GetPortablePathToTestFixtures(t)
dirCases := []struct {
Input string
Output bool
}{
// file exists locally
{"file:///C:/Temp/SomeDir/myfile.txt", true},
// file is not supposed to exist locally
{"https://myfile.iso", true},
// file does not exist locally
{"file:///C/i/dont/exist", false},
}
// create absolute-pathed tempfile to play with
err := os.Mkdir("C:\\Temp\\SomeDir", 0755)
if err != nil {
t.Fatalf("err creating test dir: %s", err)
}
fi, err := os.Create("C:\\Temp\\SomeDir\\myfile.txt")
if err != nil {
t.Fatalf("err creating test file: %s", err)
}
fi.Close()
defer os.Remove("C:\\Temp\\SomeDir\\myfile.txt")
defer os.Remove("C:\\Temp\\SomeDir")
// Run through test cases to make sure they all parse correctly dirCases := []struct {
for _, tc := range dirCases { Input string
fileOK := FileExistsLocally(tc.Input) Output bool
if !fileOK { }{
t.Fatalf("Test Case failed: Expected %#v, received = %#v, input = %s", // file exists locally
tc.Output, fileOK, tc.Input) {fmt.Sprintf("file://%s/SomeDir/myfile.txt", portablepath), true},
} // remote protocols short-circuit and are considered to exist locally
{"https://myfile.iso", true},
// non-existent protocols do not exist and hence fail
{"nonexistent-protocol://myfile.iso", false},
// file does not exist locally
{"file:///C/i/dont/exist", false},
}
// Run through test cases to make sure they all parse correctly
for _, tc := range dirCases {
fileOK := FileExistsLocally(tc.Input)
if fileOK != tc.Output {
t.Fatalf("Test Case failed: Expected %#v, received = %#v, input = %s",
tc.Output, fileOK, tc.Input)
} }
} }
} }

View File

@ -10,12 +10,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash" "hash"
"io"
"log" "log"
"net/http"
"net/url" "net/url"
"os" "os"
"path"
"runtime" "runtime"
"strings"
)
// imports related to each Downloader implementation
import (
"io"
"net/http"
"path/filepath"
) )
// DownloadConfig is the configuration given to instantiate a new // DownloadConfig is the configuration given to instantiate a new
@ -75,23 +82,38 @@ func HashForType(t string) hash.Hash {
// NewDownloadClient returns a new DownloadClient for the given // NewDownloadClient returns a new DownloadClient for the given
// configuration. // configuration.
func NewDownloadClient(c *DownloadConfig) *DownloadClient { func NewDownloadClient(c *DownloadConfig) *DownloadClient {
const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */
// Create downloader map if it hasn't been specified already.
if c.DownloaderMap == nil { if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{ c.DownloaderMap = map[string]Downloader{
"file": &FileDownloader{bufferSize: nil},
"http": &HTTPDownloader{userAgent: c.UserAgent}, "http": &HTTPDownloader{userAgent: c.UserAgent},
"https": &HTTPDownloader{userAgent: c.UserAgent}, "https": &HTTPDownloader{userAgent: c.UserAgent},
"smb": &SMBDownloader{bufferSize: nil},
} }
} }
return &DownloadClient{config: c} return &DownloadClient{config: c}
} }
// A downloader is responsible for actually taking a remote URL and // A downloader implements the ability to transfer, cancel, or resume a file.
// downloading it.
type Downloader interface { type Downloader interface {
Resume()
Cancel() Cancel()
Progress() uint64
Total() uint64
}
// A LocalDownloader is responsible for converting a uri to a local path
// that the platform can open directly.
type LocalDownloader interface {
toPath(string, url.URL) (string, error)
}
// A RemoteDownloader is responsible for actually taking a remote URL and
// downloading it.
type RemoteDownloader interface {
Download(*os.File, *url.URL) error Download(*os.File, *url.URL) error
Progress() uint
Total() uint
} }
func (d *DownloadClient) Cancel() { func (d *DownloadClient) Cancel() {
@ -105,62 +127,64 @@ func (d *DownloadClient) Get() (string, error) {
return d.config.TargetPath, nil return d.config.TargetPath, nil
} }
/* parse the configuration url into a net/url object */
u, err := url.Parse(d.config.Url) u, err := url.Parse(d.config.Url)
if err != nil { if err != nil {
return "", err return "", err
} }
log.Printf("Parsed URL: %#v", u) log.Printf("Parsed URL: %#v", u)
// Files when we don't copy the file are special cased. /* use the current working directory as the base for relative uri's */
var f *os.File cwd, err := os.Getwd()
if err != nil {
return "", err
}
// Determine which is the correct downloader to use
var finalPath string var finalPath string
sourcePath := ""
if u.Scheme == "file" && !d.config.CopyFile {
// This is special case for relative path in this case user specify
// file:../ and after parse destination goes to Opaque
if u.Path != "" {
// If url.Path is set just use this
finalPath = u.Path
} else if u.Opaque != "" {
// otherwise try url.Opaque
finalPath = u.Opaque
}
// This is a special case where we use a source file that already exists
// locally and we don't make a copy. Normally we would copy or download.
log.Printf("[DEBUG] Using local file: %s", finalPath)
// Remove forward slash on absolute Windows file URLs before processing var ok bool
if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' { d.downloader, ok = d.config.DownloaderMap[u.Scheme]
finalPath = finalPath[1:] if !ok {
} return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
}
// Keep track of the source so we can make sure not to delete this later remote, ok := d.downloader.(RemoteDownloader)
sourcePath = finalPath if !ok {
if _, err = os.Stat(finalPath); err != nil { return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, d.downloader)
return "", err }
}
} else { local, ok := d.downloader.(LocalDownloader)
if !ok && !d.config.CopyFile {
d.config.CopyFile = true
}
// If we're copying the file, then just use the actual downloader
if d.config.CopyFile {
var f *os.File
finalPath = d.config.TargetPath finalPath = d.config.TargetPath
var ok bool
d.downloader, ok = d.config.DownloaderMap[u.Scheme]
if !ok {
return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
}
// Otherwise, download using the downloader.
f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
if err != nil { if err != nil {
return "", err return "", err
} }
log.Printf("[DEBUG] Downloading: %s", u.String()) log.Printf("[DEBUG] Downloading: %s", u.String())
err = d.downloader.Download(f, u) err = remote.Download(f, u)
f.Close() f.Close()
if err != nil { if err != nil {
return "", err return "", err
} }
// Otherwise if our Downloader is a LocalDownloader we can just use the
// path after transforming it.
} else {
finalPath, err = local.toPath(cwd, *u)
if err != nil {
return "", err
}
log.Printf("[DEBUG] Using local file: %s", finalPath)
} }
if d.config.Hash != nil { if d.config.Hash != nil {
@ -168,7 +192,7 @@ func (d *DownloadClient) Get() (string, error) {
verify, err = d.VerifyChecksum(finalPath) verify, err = d.VerifyChecksum(finalPath)
if err == nil && !verify { if err == nil && !verify {
// Only delete the file if we made a copy or downloaded it // Only delete the file if we made a copy or downloaded it
if sourcePath != finalPath { if d.config.CopyFile {
os.Remove(finalPath) os.Remove(finalPath)
} }
@ -181,7 +205,6 @@ func (d *DownloadClient) Get() (string, error) {
return finalPath, err return finalPath, err
} }
// PercentProgress returns the download progress as a percentage.
func (d *DownloadClient) PercentProgress() int { func (d *DownloadClient) PercentProgress() int {
if d.downloader == nil { if d.downloader == nil {
return -1 return -1
@ -212,17 +235,21 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
// HTTPDownloader is an implementation of Downloader that downloads // HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP. // files over HTTP.
type HTTPDownloader struct { type HTTPDownloader struct {
progress uint current uint64
total uint total uint64
userAgent string userAgent string
} }
func (*HTTPDownloader) Cancel() { func (d *HTTPDownloader) Cancel() {
// TODO(mitchellh): Implement
}
func (d *HTTPDownloader) Resume() {
// TODO(mitchellh): Implement // TODO(mitchellh): Implement
} }
func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
log.Printf("Starting download: %s", src.String()) log.Printf("Starting download over HTTP: %s", src.String())
// Seek to the beginning by default // Seek to the beginning by default
if _, err := dst.Seek(0, 0); err != nil { if _, err := dst.Seek(0, 0); err != nil {
@ -230,7 +257,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
} }
// Reset our progress // Reset our progress
d.progress = 0 d.current = 0
// Make the request. We first make a HEAD request so we can check // Make the request. We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't // if the server supports range queries. If the server/URL doesn't
@ -258,7 +285,8 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
if fi, err := dst.Stat(); err == nil { if fi, err := dst.Stat(); err == nil {
if _, err = dst.Seek(0, os.SEEK_END); err == nil { if _, err = dst.Seek(0, os.SEEK_END); err == nil {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))
d.progress = uint(fi.Size())
d.current = uint64(fi.Size())
} }
} }
} }
@ -272,7 +300,8 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err return err
} }
d.total = d.progress + uint(resp.ContentLength) d.total = d.current + uint64(resp.ContentLength)
var buffer [4096]byte var buffer [4096]byte
for { for {
n, err := resp.Body.Read(buffer[:]) n, err := resp.Body.Read(buffer[:])
@ -280,7 +309,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err return err
} }
d.progress += uint(n) d.current += uint64(n)
if _, werr := dst.Write(buffer[:n]); werr != nil { if _, werr := dst.Write(buffer[:n]); werr != nil {
return werr return werr
@ -290,14 +319,253 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
break break
} }
} }
return nil return nil
} }
func (d *HTTPDownloader) Progress() uint { func (d *HTTPDownloader) Progress() uint64 {
return d.progress return d.current
} }
func (d *HTTPDownloader) Total() uint { func (d *HTTPDownloader) Total() uint64 {
return d.total return d.total
} }
// FileDownloader is an implementation of Downloader that downloads
// files using the regular filesystem.
type FileDownloader struct {
bufferSize *uint
active bool
current uint64
total uint64
}
func (d *FileDownloader) Progress() uint64 {
return d.current
}
func (d *FileDownloader) Total() uint64 {
return d.total
}
func (d *FileDownloader) Cancel() {
d.active = false
}
func (d *FileDownloader) Resume() {
// TODO: Implement
}
func (d *FileDownloader) toPath(base string, uri url.URL) (string, error) {
var result string
// absolute path -- file://c:/absolute/path -> c:/absolute/path
if strings.HasSuffix(uri.Host, ":") {
result = path.Join(uri.Host, uri.Path)
// semi-absolute path (current drive letter)
// -- file:///absolute/path -> drive:/absolute/path
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
apath := uri.Path
components := strings.Split(apath, "/")
volume := filepath.VolumeName(base)
// semi-absolute absolute path (includes volume letter)
// -- file://drive:/path -> drive:/absolute/path
if len(components) > 1 && strings.HasSuffix(components[1], ":") {
volume = components[1]
apath = path.Join(components[2:]...)
}
result = path.Join(volume, apath)
// relative path -- file://./relative/path -> ./relative/path
} else if uri.Host == "." {
result = path.Join(base, uri.Path)
// relative path -- file://relative/path -> ./relative/path
} else {
result = path.Join(base, uri.Host, uri.Path)
}
return filepath.ToSlash(result), nil
}
func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
d.active = false
/* check the uri's scheme to make sure it matches */
if src == nil || src.Scheme != "file" {
return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
}
uri := src
/* use the current working directory as the base for relative uri's */
cwd, err := os.Getwd()
if err != nil {
return err
}
/* determine which uri format is being used and convert to a real path */
realpath, err := d.toPath(cwd, *uri)
if err != nil {
return err
}
/* download the file using the operating system's facilities */
d.current = 0
d.active = true
f, err := os.Open(realpath)
if err != nil {
return err
}
defer f.Close()
// get the file size
fi, err := f.Stat()
if err != nil {
return err
}
d.total = uint64(fi.Size())
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n, err = io.Copy(dst, f)
d.active = false
d.current += uint64(n)
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.current += uint64(n)
}
d.active = false
e <- err
}(d, f, dst, errch)
// ...and we spin until it's done
err = <-errch
}
f.Close()
return err
}
// SMBDownloader is an implementation of Downloader that downloads
// files using the "\\" path format on Windows
type SMBDownloader struct {
bufferSize *uint
active bool
current uint64
total uint64
}
func (d *SMBDownloader) Progress() uint64 {
return d.current
}
func (d *SMBDownloader) Total() uint64 {
return d.total
}
func (d *SMBDownloader) Cancel() {
d.active = false
}
func (d *SMBDownloader) Resume() {
// TODO: Implement
}
func (d *SMBDownloader) toPath(base string, uri url.URL) (string, error) {
const UNCPrefix = string(os.PathSeparator) + string(os.PathSeparator)
if runtime.GOOS != "windows" {
return "", fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
}
return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil
}
func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
/* first we warn the world if we're not running windows */
if runtime.GOOS != "windows" {
return fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
}
d.active = false
/* convert the uri using the net/url module to a UNC path */
if src == nil || src.Scheme != "smb" {
return fmt.Errorf("Unexpected uri scheme: %s", src.Scheme)
}
uri := src
/* use the current working directory as the base for relative uri's */
cwd, err := os.Getwd()
if err != nil {
return err
}
/* convert uri to an smb-path */
realpath, err := d.toPath(cwd, *uri)
if err != nil {
return err
}
/* Open up the "\\"-prefixed path using the Windows filesystem */
d.current = 0
d.active = true
f, err := os.Open(realpath)
if err != nil {
return err
}
defer f.Close()
// get the file size (at the risk of performance)
fi, err := f.Stat()
if err != nil {
return err
}
d.total = uint64(fi.Size())
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n, err = io.Copy(dst, f)
d.active = false
d.current += uint64(n)
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.current += uint64(n)
}
d.active = false
e <- err
}(d, f, dst, errch)
// ...and as usual we spin until it's done
err = <-errch
}
f.Close()
return err
}

View File

@ -8,7 +8,9 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strings"
"testing" "testing"
) )
@ -56,6 +58,7 @@ func TestDownloadClient_basic(t *testing.T) {
client := NewDownloadClient(&DownloadConfig{ client := NewDownloadClient(&DownloadConfig{
Url: ts.URL + "/basic.txt", Url: ts.URL + "/basic.txt",
TargetPath: tf.Name(), TargetPath: tf.Name(),
CopyFile: true,
}) })
path, err := client.Get() path, err := client.Get()
@ -91,6 +94,7 @@ func TestDownloadClient_checksumBad(t *testing.T) {
TargetPath: tf.Name(), TargetPath: tf.Name(),
Hash: HashForType("md5"), Hash: HashForType("md5"),
Checksum: checksum, Checksum: checksum,
CopyFile: true,
}) })
if _, err := client.Get(); err == nil { if _, err := client.Get(); err == nil {
t.Fatal("should error") t.Fatal("should error")
@ -115,6 +119,7 @@ func TestDownloadClient_checksumGood(t *testing.T) {
TargetPath: tf.Name(), TargetPath: tf.Name(),
Hash: HashForType("md5"), Hash: HashForType("md5"),
Checksum: checksum, Checksum: checksum,
CopyFile: true,
}) })
path, err := client.Get() path, err := client.Get()
if err != nil { if err != nil {
@ -145,6 +150,7 @@ func TestDownloadClient_checksumNoDownload(t *testing.T) {
TargetPath: "./test-fixtures/root/another.txt", TargetPath: "./test-fixtures/root/another.txt",
Hash: HashForType("md5"), Hash: HashForType("md5"),
Checksum: checksum, Checksum: checksum,
CopyFile: true,
}) })
path, err := client.Get() path, err := client.Get()
if err != nil { if err != nil {
@ -183,6 +189,7 @@ func TestDownloadClient_resume(t *testing.T) {
client := NewDownloadClient(&DownloadConfig{ client := NewDownloadClient(&DownloadConfig{
Url: ts.URL, Url: ts.URL,
TargetPath: tf.Name(), TargetPath: tf.Name(),
CopyFile: true,
}) })
path, err := client.Get() path, err := client.Get()
if err != nil { if err != nil {
@ -240,6 +247,7 @@ func TestDownloadClient_usesDefaultUserAgent(t *testing.T) {
config := &DownloadConfig{ config := &DownloadConfig{
Url: server.URL, Url: server.URL,
TargetPath: tf.Name(), TargetPath: tf.Name(),
CopyFile: true,
} }
client := NewDownloadClient(config) client := NewDownloadClient(config)
@ -271,6 +279,7 @@ func TestDownloadClient_setsUserAgent(t *testing.T) {
Url: server.URL, Url: server.URL,
TargetPath: tf.Name(), TargetPath: tf.Name(),
UserAgent: "fancy user agent", UserAgent: "fancy user agent",
CopyFile: true,
} }
client := NewDownloadClient(config) client := NewDownloadClient(config)
@ -351,6 +360,7 @@ func TestDownloadFileUrl(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Unable to detect working directory: %s", err) t.Fatalf("Unable to detect working directory: %s", err)
} }
cwd = filepath.ToSlash(cwd)
// source_path is a file path and source is a network path // source_path is a file path and source is a network path
sourcePath := fmt.Sprintf("%s/test-fixtures/fileurl/%s", cwd, "cake") sourcePath := fmt.Sprintf("%s/test-fixtures/fileurl/%s", cwd, "cake")
@ -376,11 +386,116 @@ func TestDownloadFileUrl(t *testing.T) {
// Verify that we fail to match the checksum // Verify that we fail to match the checksum
_, err = client.Get() _, err = client.Get()
if err.Error() != "checksums didn't match expected: 6e6f7065" { if err.Error() != "checksums didn't match expected: 6e6f7065" {
t.Fatalf("Unexpected failure; expected checksum not to match") t.Fatalf("Unexpected failure; expected checksum not to match. Error was \"%v\"", err)
} }
if _, err = os.Stat(sourcePath); err != nil { if _, err = os.Stat(sourcePath); err != nil {
t.Errorf("Could not stat source file: %s", sourcePath) t.Errorf("Could not stat source file: %s", sourcePath)
} }
}
// SimulateFileUriDownload is a simple utility function that converts a uri
// into a testable file path whilst ignoring a correct checksum match, stripping
// UNC path info, and then calling stat to ensure the correct file exists.
// (used by TestFileUriTransforms)
func SimulateFileUriDownload(t *testing.T, uri string) (string, error) {
// source_path is a file path and source is a network path
source := fmt.Sprintf(uri)
t.Logf("Trying to download %s", source)
config := &DownloadConfig{
Url: source,
// This should be wrong. We want to make sure we don't delete
Checksum: []byte("nope"),
Hash: HashForType("sha256"),
CopyFile: false,
}
// go go go
client := NewDownloadClient(config)
path, err := client.Get()
// ignore any non-important checksum errors if it's not a unc path
if !strings.HasPrefix(path, "\\\\") && err.Error() != "checksums didn't match expected: 6e6f7065" {
t.Fatalf("Unexpected failure; expected checksum not to match")
}
// if it's a unc path, then remove the host and share name so we don't have
// to force the user to enable ADMIN$ and Windows File Sharing
if strings.HasPrefix(path, "\\\\") {
res := strings.SplitN(path, "/", 3)
path = "/" + res[2]
}
if _, err = os.Stat(path); err != nil {
t.Errorf("Could not stat source file: %s", path)
}
return path, err
}
// TestFileUriTransforms tests the case where we use a local file uri
// for iso_url. There's a few different formats that a file uri can exist as
// and so we try to test the most useful and common ones.
func TestFileUriTransforms(t *testing.T) {
const testpath = /* have your */ "test-fixtures/fileurl/cake" /* and eat it too */
const host = "localhost"
var cwd string
var volume string
var share string
cwd, err := os.Getwd()
if err != nil {
t.Fatalf("Unable to detect working directory: %s", err)
return
}
cwd = filepath.ToSlash(cwd)
volume = filepath.VolumeName(cwd)
share = volume
// if a volume was found (on windows), replace the ':' from
// C: to C$ to convert it into a hidden windows share.
if len(share) > 1 && share[len(share)-1] == ':' {
share = share[:len(share)-1] + "$"
}
cwd = cwd[len(volume):]
t.Logf("TestFileUriTransforms : Running with cwd : '%s'", cwd)
t.Logf("TestFileUriTransforms : Running with volume : '%s'", volume)
// ./relative/path -> ./relative/path
// /absolute/path -> /absolute/path
// c:/windows/absolute -> c:/windows/absolute
testcases := []string{
"./%s",
cwd + "/%s",
volume + cwd + "/%s",
}
// all regular slashed testcases
for _, testcase := range testcases {
uri := "file://" + fmt.Sprintf(testcase, testpath)
t.Logf("TestFileUriTransforms : Trying Uri '%s'", uri)
res, err := SimulateFileUriDownload(t, uri)
if err != nil {
t.Errorf("Unable to transform uri '%s' into a path : %v", uri, err)
}
t.Logf("TestFileUriTransforms : Result Path '%s'", res)
}
// smb protocol depends on platform support which currently
// only exists on windows.
if runtime.GOOS == "windows" {
// ...and finally the oddball windows native path
// smb://host/sharename/file -> \\host\sharename\file
testcase := host + "/" + share + "/" + cwd[1:] + "/%s"
uri := "smb://" + fmt.Sprintf(testcase, testpath)
t.Logf("TestFileUriTransforms : Trying Uri '%s'", uri)
res, err := SimulateFileUriDownload(t, uri)
if err != nil {
t.Errorf("Unable to transform uri '%s' into a path", uri)
return
}
t.Logf("TestFileUriTransforms : Result Path '%s'", res)
}
} }

View File

@ -111,7 +111,7 @@ func (c *ISOConfig) Prepare(ctx *interpolate.Context) (warnings []string, errs [
c.ISOChecksum = strings.ToLower(c.ISOChecksum) c.ISOChecksum = strings.ToLower(c.ISOChecksum)
for i, url := range c.ISOUrls { for i, url := range c.ISOUrls {
url, err := DownloadableURL(url) url, err := ValidatedURL(url)
if err != nil { if err != nil {
errs = append( errs = append(
errs, fmt.Errorf("Failed to parse iso_url %d: %s", i+1, err)) errs, fmt.Errorf("Failed to parse iso_url %d: %s", i+1, err))

View File