HTTPDownloader uses UserAgent from DownloadConfig

This commit is contained in:
Mark Rushakoff 2014-01-09 08:41:34 -08:00
parent 73e57d918f
commit 9e5c0f6c6a
2 changed files with 100 additions and 4 deletions

View File

@ -43,6 +43,10 @@ type DownloadConfig struct {
// for the downloader will be used to verify with this checksum after // for the downloader will be used to verify with this checksum after
// it is downloaded. // it is downloaded.
Checksum []byte Checksum []byte
// What to use for the user agent for HTTP requests. If set to "", use the
// default user agent provided by Go.
UserAgent string
} }
// A DownloadClient helps download, verify checksums, etc. // A DownloadClient helps download, verify checksums, etc.
@ -73,8 +77,8 @@ func HashForType(t string) hash.Hash {
func NewDownloadClient(c *DownloadConfig) *DownloadClient { func NewDownloadClient(c *DownloadConfig) *DownloadClient {
if c.DownloaderMap == nil { if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{ c.DownloaderMap = map[string]Downloader{
"http": new(HTTPDownloader), "http": &HTTPDownloader{userAgent: c.UserAgent},
"https": new(HTTPDownloader), "https": &HTTPDownloader{userAgent: c.UserAgent},
} }
} }
@ -182,8 +186,9 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
// HTTPDownloader is an implementation of Downloader that downloads // HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP. // files over HTTP.
type HTTPDownloader struct { type HTTPDownloader struct {
progress uint progress uint
total uint total uint
userAgent string
} }
func (*HTTPDownloader) Cancel() { func (*HTTPDownloader) Cancel() {
@ -197,6 +202,10 @@ func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
return err return err
} }
if d.userAgent != "" {
req.Header.Set("User-Agent", d.userAgent)
}
httpClient := &http.Client{ httpClient := &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,

View File

@ -4,6 +4,8 @@ import (
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"io/ioutil" "io/ioutil"
"net/http"
"net/http/httptest"
"os" "os"
"testing" "testing"
) )
@ -41,6 +43,91 @@ func TestDownloadClient_VerifyChecksum(t *testing.T) {
} }
} }
func TestDownloadClientUsesDefaultUserAgent(t *testing.T) {
tf, err := ioutil.TempFile("", "packer")
if err != nil {
t.Fatalf("tempfile error: %s", err)
}
defer os.Remove(tf.Name())
defaultUserAgent := ""
asserted := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if defaultUserAgent == "" {
defaultUserAgent = r.UserAgent()
} else {
incomingUserAgent := r.UserAgent()
if incomingUserAgent != defaultUserAgent {
t.Fatalf("Expected user agent %s, got: %s", defaultUserAgent, incomingUserAgent)
}
asserted = true
}
}))
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
_, err = httpClient.Do(req)
if err != nil {
t.Fatal(err)
}
config := &DownloadConfig{
Url: server.URL,
TargetPath: tf.Name(),
}
client := NewDownloadClient(config)
_, err = client.Get()
if err != nil {
t.Fatal(err)
}
if !asserted {
t.Fatal("User-Agent never observed")
}
}
func TestDownloadClientSetsUserAgent(t *testing.T) {
tf, err := ioutil.TempFile("", "packer")
if err != nil {
t.Fatalf("tempfile error: %s", err)
}
defer os.Remove(tf.Name())
asserted := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
asserted = true
if r.UserAgent() != "fancy user agent" {
t.Fatalf("Expected useragent fancy user agent, got: %s", r.UserAgent())
}
}))
config := &DownloadConfig{
Url: server.URL,
TargetPath: tf.Name(),
UserAgent: "fancy user agent",
}
client := NewDownloadClient(config)
_, err = client.Get()
if err != nil {
t.Fatal(err)
}
if !asserted {
t.Fatal("HTTP request never made")
}
}
func TestHashForType(t *testing.T) { func TestHashForType(t *testing.T) {
if h := HashForType("md5"); h == nil { if h := HashForType("md5"); h == nil {
t.Fatalf("md5 hash is nil") t.Fatalf("md5 hash is nil")