diff --git a/common/download.go b/common/download.go index 2a0b37ee7..b5798b76c 100644 --- a/common/download.go +++ b/common/download.go @@ -43,6 +43,10 @@ type DownloadConfig struct { // for the downloader will be used to verify with this checksum after // it is downloaded. Checksum []byte + + // What to use for the user agent for HTTP requests. If set to "", use the + // default user agent provided by Go. + UserAgent string } // A DownloadClient helps download, verify checksums, etc. @@ -73,8 +77,8 @@ func HashForType(t string) hash.Hash { func NewDownloadClient(c *DownloadConfig) *DownloadClient { if c.DownloaderMap == nil { c.DownloaderMap = map[string]Downloader{ - "http": new(HTTPDownloader), - "https": new(HTTPDownloader), + "http": &HTTPDownloader{userAgent: c.UserAgent}, + "https": &HTTPDownloader{userAgent: c.UserAgent}, } } @@ -182,8 +186,9 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { // HTTPDownloader is an implementation of Downloader that downloads // files over HTTP. type HTTPDownloader struct { - progress uint - total uint + progress uint + total uint + userAgent string } func (*HTTPDownloader) Cancel() { @@ -197,6 +202,10 @@ func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error { return err } + if d.userAgent != "" { + req.Header.Set("User-Agent", d.userAgent) + } + httpClient := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, diff --git a/common/download_test.go b/common/download_test.go index f1fec941e..57b4ba7bc 100644 --- a/common/download_test.go +++ b/common/download_test.go @@ -4,6 +4,8 @@ import ( "crypto/md5" "encoding/hex" "io/ioutil" + "net/http" + "net/http/httptest" "os" "testing" ) @@ -41,6 +43,91 @@ func TestDownloadClient_VerifyChecksum(t *testing.T) { } } +func TestDownloadClientUsesDefaultUserAgent(t *testing.T) { + tf, err := ioutil.TempFile("", "packer") + if err != nil { + t.Fatalf("tempfile error: %s", err) + } + defer os.Remove(tf.Name()) + + defaultUserAgent := "" + asserted := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if defaultUserAgent == "" { + defaultUserAgent = r.UserAgent() + } else { + incomingUserAgent := r.UserAgent() + if incomingUserAgent != defaultUserAgent { + t.Fatalf("Expected user agent %s, got: %s", defaultUserAgent, incomingUserAgent) + } + + asserted = true + } + })) + + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatal(err) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + } + + _, err = httpClient.Do(req) + if err != nil { + t.Fatal(err) + } + + config := &DownloadConfig{ + Url: server.URL, + TargetPath: tf.Name(), + } + + client := NewDownloadClient(config) + _, err = client.Get() + if err != nil { + t.Fatal(err) + } + + if !asserted { + t.Fatal("User-Agent never observed") + } +} + +func TestDownloadClientSetsUserAgent(t *testing.T) { + tf, err := ioutil.TempFile("", "packer") + if err != nil { + t.Fatalf("tempfile error: %s", err) + } + defer os.Remove(tf.Name()) + + asserted := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + asserted = true + if r.UserAgent() != "fancy user agent" { + t.Fatalf("Expected useragent fancy user agent, got: %s", r.UserAgent()) + } + })) + config := &DownloadConfig{ + Url: server.URL, + TargetPath: tf.Name(), + UserAgent: "fancy user agent", + } + + client := NewDownloadClient(config) + _, err = client.Get() + if err != nil { + t.Fatal(err) + } + + if !asserted { + t.Fatal("HTTP request never made") + } +} + func TestHashForType(t *testing.T) { if h := HashForType("md5"); h == nil { t.Fatalf("md5 hash is nil")