Merge pull request #803 from mark-rushakoff/useragent

common: set user agent in downloader
This commit is contained in:
Mitchell Hashimoto 2014-01-19 15:46:45 -08:00
commit 3f77b2c592
4 changed files with 111 additions and 10 deletions

View File

@ -43,6 +43,10 @@ type DownloadConfig struct {
// for the downloader will be used to verify with this checksum after
// it is downloaded.
Checksum []byte
// What to use for the user agent for HTTP requests. If set to "", use the
// default user agent provided by Go.
UserAgent string
}
// A DownloadClient helps download, verify checksums, etc.
@ -73,8 +77,8 @@ func HashForType(t string) hash.Hash {
func NewDownloadClient(c *DownloadConfig) *DownloadClient {
if c.DownloaderMap == nil {
c.DownloaderMap = map[string]Downloader{
"http": new(HTTPDownloader),
"https": new(HTTPDownloader),
"http": &HTTPDownloader{userAgent: c.UserAgent},
"https": &HTTPDownloader{userAgent: c.UserAgent},
}
}
@ -182,8 +186,9 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
// HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP.
type HTTPDownloader struct {
progress uint
total uint
progress uint
total uint
userAgent string
}
func (*HTTPDownloader) Cancel() {
@ -197,6 +202,10 @@ func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error {
return err
}
if d.userAgent != "" {
req.Header.Set("User-Agent", d.userAgent)
}
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,

View File

@ -4,6 +4,8 @@ import (
"crypto/md5"
"encoding/hex"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
)
@ -41,6 +43,91 @@ func TestDownloadClient_VerifyChecksum(t *testing.T) {
}
}
func TestDownloadClientUsesDefaultUserAgent(t *testing.T) {
tf, err := ioutil.TempFile("", "packer")
if err != nil {
t.Fatalf("tempfile error: %s", err)
}
defer os.Remove(tf.Name())
defaultUserAgent := ""
asserted := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if defaultUserAgent == "" {
defaultUserAgent = r.UserAgent()
} else {
incomingUserAgent := r.UserAgent()
if incomingUserAgent != defaultUserAgent {
t.Fatalf("Expected user agent %s, got: %s", defaultUserAgent, incomingUserAgent)
}
asserted = true
}
}))
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
_, err = httpClient.Do(req)
if err != nil {
t.Fatal(err)
}
config := &DownloadConfig{
Url: server.URL,
TargetPath: tf.Name(),
}
client := NewDownloadClient(config)
_, err = client.Get()
if err != nil {
t.Fatal(err)
}
if !asserted {
t.Fatal("User-Agent never observed")
}
}
func TestDownloadClientSetsUserAgent(t *testing.T) {
tf, err := ioutil.TempFile("", "packer")
if err != nil {
t.Fatalf("tempfile error: %s", err)
}
defer os.Remove(tf.Name())
asserted := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
asserted = true
if r.UserAgent() != "fancy user agent" {
t.Fatalf("Expected useragent fancy user agent, got: %s", r.UserAgent())
}
}))
config := &DownloadConfig{
Url: server.URL,
TargetPath: tf.Name(),
UserAgent: "fancy user agent",
}
client := NewDownloadClient(config)
_, err = client.Get()
if err != nil {
t.Fatal(err)
}
if !asserted {
t.Fatal("HTTP request never made")
}
}
func TestHashForType(t *testing.T) {
if h := HashForType("md5"); h == nil {
t.Fatalf("md5 hash is nil")

View File

@ -70,6 +70,7 @@ func (s *StepDownload) Run(state multistep.StateBag) multistep.StepAction {
CopyFile: false,
Hash: HashForType(s.ChecksumType),
Checksum: checksum,
UserAgent: packer.VersionString(),
}
path, err, retry := s.download(config, state)

View File

@ -31,6 +31,15 @@ func (versionCommand) Run(env Environment, args []string) int {
env.Ui().Machine("version-prelease", VersionPrerelease)
env.Ui().Machine("version-commit", GitCommit)
env.Ui().Say(VersionString())
return 0
}
func (versionCommand) Synopsis() string {
return "print Packer version"
}
func VersionString() string {
var versionString bytes.Buffer
fmt.Fprintf(&versionString, "Packer v%s", Version)
if VersionPrerelease != "" {
@ -41,10 +50,5 @@ func (versionCommand) Run(env Environment, args []string) int {
}
}
env.Ui().Say(versionString.String())
return 0
}
func (versionCommand) Synopsis() string {
return "print Packer version"
return versionString.String()
}