HTTPDownloader uses UserAgent from DownloadConfig
This commit is contained in:
parent
73e57d918f
commit
9e5c0f6c6a
|
@ -43,6 +43,10 @@ type DownloadConfig struct {
|
|||
// for the downloader will be used to verify with this checksum after
|
||||
// it is downloaded.
|
||||
Checksum []byte
|
||||
|
||||
// What to use for the user agent for HTTP requests. If set to "", use the
|
||||
// default user agent provided by Go.
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// A DownloadClient helps download, verify checksums, etc.
|
||||
|
@ -73,8 +77,8 @@ func HashForType(t string) hash.Hash {
|
|||
func NewDownloadClient(c *DownloadConfig) *DownloadClient {
|
||||
if c.DownloaderMap == nil {
|
||||
c.DownloaderMap = map[string]Downloader{
|
||||
"http": new(HTTPDownloader),
|
||||
"https": new(HTTPDownloader),
|
||||
"http": &HTTPDownloader{userAgent: c.UserAgent},
|
||||
"https": &HTTPDownloader{userAgent: c.UserAgent},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -182,8 +186,9 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
|
|||
// HTTPDownloader is an implementation of Downloader that downloads
|
||||
// files over HTTP.
|
||||
type HTTPDownloader struct {
|
||||
progress uint
|
||||
total uint
|
||||
progress uint
|
||||
total uint
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (*HTTPDownloader) Cancel() {
|
||||
|
@ -197,6 +202,10 @@ func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if d.userAgent != "" {
|
||||
req.Header.Set("User-Agent", d.userAgent)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
@ -41,6 +43,91 @@ func TestDownloadClient_VerifyChecksum(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDownloadClientUsesDefaultUserAgent(t *testing.T) {
|
||||
tf, err := ioutil.TempFile("", "packer")
|
||||
if err != nil {
|
||||
t.Fatalf("tempfile error: %s", err)
|
||||
}
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
defaultUserAgent := ""
|
||||
asserted := false
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if defaultUserAgent == "" {
|
||||
defaultUserAgent = r.UserAgent()
|
||||
} else {
|
||||
incomingUserAgent := r.UserAgent()
|
||||
if incomingUserAgent != defaultUserAgent {
|
||||
t.Fatalf("Expected user agent %s, got: %s", defaultUserAgent, incomingUserAgent)
|
||||
}
|
||||
|
||||
asserted = true
|
||||
}
|
||||
}))
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = httpClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &DownloadConfig{
|
||||
Url: server.URL,
|
||||
TargetPath: tf.Name(),
|
||||
}
|
||||
|
||||
client := NewDownloadClient(config)
|
||||
_, err = client.Get()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !asserted {
|
||||
t.Fatal("User-Agent never observed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadClientSetsUserAgent(t *testing.T) {
|
||||
tf, err := ioutil.TempFile("", "packer")
|
||||
if err != nil {
|
||||
t.Fatalf("tempfile error: %s", err)
|
||||
}
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
asserted := false
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
asserted = true
|
||||
if r.UserAgent() != "fancy user agent" {
|
||||
t.Fatalf("Expected useragent fancy user agent, got: %s", r.UserAgent())
|
||||
}
|
||||
}))
|
||||
config := &DownloadConfig{
|
||||
Url: server.URL,
|
||||
TargetPath: tf.Name(),
|
||||
UserAgent: "fancy user agent",
|
||||
}
|
||||
|
||||
client := NewDownloadClient(config)
|
||||
_, err = client.Get()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !asserted {
|
||||
t.Fatal("HTTP request never made")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashForType(t *testing.T) {
|
||||
if h := HashForType("md5"); h == nil {
|
||||
t.Fatalf("md5 hash is nil")
|
||||
|
|
Loading…
Reference in New Issue