diff --git a/builder/amazon/common/access_config.go b/builder/amazon/common/access_config.go index 5c05f4280..66b595300 100644 --- a/builder/amazon/common/access_config.go +++ b/builder/amazon/common/access_config.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/packer/template/interpolate" ) @@ -29,6 +30,8 @@ type AccessConfig struct { SkipMetadataApiCheck bool `mapstructure:"skip_metadata_api_check"` Token string `mapstructure:"token"` session *session.Session + + getEC2Connection func() ec2iface.EC2API } // Config returns a valid aws.Config object for access to AWS services, or @@ -148,12 +151,7 @@ func (c *AccessConfig) Prepare(ctx *interpolate.Context) []error { } if c.RawRegion != "" && !c.SkipValidation { - sess, err := c.Session() - if err != nil { - errs = append(errs, err) - } - ec2conn := ec2.New(sess) - err = ValidateRegion(c.RawRegion, ec2conn) + err := c.ValidateRegion(c.RawRegion) if err != nil { errs = append(errs, fmt.Errorf("error validating region: %s", err.Error())) } @@ -161,3 +159,14 @@ func (c *AccessConfig) Prepare(ctx *interpolate.Context) []error { return errs } + +func (c *AccessConfig) NewEC2Connection() (ec2iface.EC2API, error) { + if c.getEC2Connection != nil { + return c.getEC2Connection(), nil + } + sess, err := c.Session() + if err != nil { + return nil, err + } + return ec2.New(sess), nil +} diff --git a/builder/amazon/common/access_config_test.go b/builder/amazon/common/access_config_test.go index c94d78bff..489bf08df 100644 --- a/builder/amazon/common/access_config_test.go +++ b/builder/amazon/common/access_config_test.go @@ -5,31 +5,34 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" ) func testAccessConfig() *AccessConfig { - return &AccessConfig{} + return &AccessConfig{ + getEC2Connection: func() ec2iface.EC2API { + return &mockEC2Client{} + }, + } } func TestAccessConfigPrepare_Region(t *testing.T) { c := testAccessConfig() - mockConn := &mockEC2Client{} - c.RawRegion = "us-east-12" - err := ValidateRegion(c.RawRegion, mockConn) + err := c.ValidateRegion(c.RawRegion) if err == nil { t.Fatalf("should have region validation err: %s", c.RawRegion) } c.RawRegion = "us-east-1" - err = ValidateRegion(c.RawRegion, mockConn) + err = c.ValidateRegion(c.RawRegion) if err != nil { t.Fatalf("shouldn't have region validation err: %s", c.RawRegion) } c.RawRegion = "custom" - err = ValidateRegion(c.RawRegion, mockConn) + err = c.ValidateRegion(c.RawRegion) if err == nil { t.Fatalf("should have region validation err: %s", c.RawRegion) } diff --git a/builder/amazon/common/ami_config.go b/builder/amazon/common/ami_config.go index ceb7dbe57..5fbac557a 100644 --- a/builder/amazon/common/ami_config.go +++ b/builder/amazon/common/ami_config.go @@ -4,8 +4,6 @@ import ( "fmt" "log" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/hashicorp/packer/template/interpolate" ) @@ -58,16 +56,7 @@ func (c *AMIConfig) Prepare(accessConfig *AccessConfig, ctx *interpolate.Context } } - var ec2conn *ec2.EC2 - if !c.AMISkipRegionValidation { - sess, err := accessConfig.Session() - if err != nil { - errs = append(errs, err) - } - ec2conn = ec2.New(sess) - } - - errs = c.prepareRegions(ec2conn, accessConfig, errs) + errs = append(errs, c.prepareRegions(accessConfig)...) if len(c.AMIUsers) > 0 && c.AMIEncryptBootVolume { errs = append(errs, fmt.Errorf("Cannot share AMI with encrypted boot volume")) @@ -105,8 +94,16 @@ func (c *AMIConfig) Prepare(accessConfig *AccessConfig, ctx *interpolate.Context return nil } -func (c *AMIConfig) prepareRegions(ec2conn ec2iface.EC2API, accessConfig *AccessConfig, errs []error) []error { +func (c *AMIConfig) prepareRegions(accessConfig *AccessConfig) (errs []error) { if len(c.AMIRegions) > 0 { + if !c.AMISkipRegionValidation { + // Verify the regions are real + err := accessConfig.ValidateRegion(c.AMIRegions...) + if err != nil { + errs = append(errs, fmt.Errorf("error validating regions: %v", err)) + } + } + regionSet := make(map[string]struct{}) regions := make([]string, 0, len(c.AMIRegions)) @@ -119,14 +116,6 @@ func (c *AMIConfig) prepareRegions(ec2conn ec2iface.EC2API, accessConfig *Access // Mark that we saw the region regionSet[region] = struct{}{} - if !c.AMISkipRegionValidation { - // Verify the region is real - err := ValidateRegion(region, ec2conn) - if err != nil { - errs = append(errs, fmt.Errorf("error validating region: %s", err.Error())) - } - } - // Make sure that if we have region_kms_key_ids defined, // the regions in ami_regions are also in region_kms_key_ids if len(c.AMIRegionKMSKeyIDs) > 0 { diff --git a/builder/amazon/common/ami_config_test.go b/builder/amazon/common/ami_config_test.go index 4aee3204c..7623d6adb 100644 --- a/builder/amazon/common/ami_config_test.go +++ b/builder/amazon/common/ami_config_test.go @@ -17,14 +17,14 @@ func testAMIConfig() *AMIConfig { } func getFakeAccessConfig(region string) *AccessConfig { - return &AccessConfig{ - RawRegion: region, - } + c := testAccessConfig() + c.RawRegion = region + return c } func TestAMIConfigPrepare_name(t *testing.T) { c := testAMIConfig() - accessConf := getFakeAccessConfig("wherever") + accessConf := testAccessConfig() c.AMISkipRegionValidation = true if err := c.Prepare(accessConf, nil); err != nil { t.Fatalf("shouldn't have err: %s", err) @@ -57,8 +57,9 @@ func TestAMIConfigPrepare_regions(t *testing.T) { var errs []error var err error + accessConf := testAccessConfig() mockConn := &mockEC2Client{} - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatalf("shouldn't have err: %#v", errs) } @@ -67,18 +68,18 @@ func TestAMIConfigPrepare_regions(t *testing.T) { if err != nil { t.Fatalf("shouldn't have err: %s", err.Error()) } - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatalf("shouldn't have err: %#v", errs) } c.AMIRegions = []string{"foo"} - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) == 0 { + if errs = c.prepareRegions(accessConf); len(errs) == 0 { t.Fatal("should have error") } errs = errs[:0] c.AMIRegions = []string{"us-east-1", "us-west-1", "us-east-1"} - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatalf("bad: %s", errs[0]) } @@ -89,7 +90,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { c.AMIRegions = []string{"custom"} c.AMISkipRegionValidation = true - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal("shouldn't have error") } c.AMISkipRegionValidation = false @@ -100,7 +101,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { "us-west-1": "789-012-3456", "us-east-2": "456-789-0123", } - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal(fmt.Sprintf("shouldn't have error: %s", errs[0])) } @@ -110,7 +111,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { "us-west-1": "789-012-3456", "us-east-2": "", } - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal("should have passed; we are able to use default KMS key if not sharing") } @@ -121,7 +122,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { "us-west-1": "789-012-3456", "us-east-2": "", } - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal("should have an error b/c can't use default KMS key if sharing") } @@ -131,7 +132,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { "us-west-1": "789-012-3456", "us-east-2": "456-789-0123", } - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal("should have error b/c theres a region in the key map that isn't in ami_regions") } @@ -142,7 +143,6 @@ func TestAMIConfigPrepare_regions(t *testing.T) { } c.AMISkipRegionValidation = true - accessConf := getFakeAccessConfig("wherever") if err := c.Prepare(accessConf, nil); err == nil { t.Fatal("should have error b/c theres a region in in ami_regions that isn't in the key map") } @@ -156,7 +156,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { "us-east-1": "123-456-7890", "us-west-1": "", } - if errs = c.prepareRegions(mockConn, nil, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal("should have error b/c theres a region in in ami_regions that isn't in the key map") } @@ -164,7 +164,7 @@ func TestAMIConfigPrepare_regions(t *testing.T) { accessConf = getFakeAccessConfig("us-east-1") c.AMIRegions = []string{"us-east-1", "us-west-1", "us-east-2"} c.AMIRegionKMSKeyIDs = nil - if errs = c.prepareRegions(mockConn, accessConf, errs); len(errs) > 0 { + if errs = c.prepareRegions(accessConf); len(errs) > 0 { t.Fatal("should allow user to have the raw region in ami_regions") } @@ -176,7 +176,7 @@ func TestAMIConfigPrepare_Share_EncryptedBoot(t *testing.T) { c.AMIUsers = []string{"testAccountID"} c.AMIEncryptBootVolume = true - accessConf := getFakeAccessConfig("wherever") + accessConf := testAccessConfig() c.AMIKmsKeyId = "" if err := c.Prepare(accessConf, nil); err == nil { @@ -193,7 +193,7 @@ func TestAMINameValidation(t *testing.T) { c := testAMIConfig() c.AMISkipRegionValidation = true - accessConf := getFakeAccessConfig("wherever") + accessConf := testAccessConfig() c.AMIName = "aa" if err := c.Prepare(accessConf, nil); err == nil { diff --git a/builder/amazon/common/regions.go b/builder/amazon/common/regions.go index 6eee8c54a..3bdda86da 100644 --- a/builder/amazon/common/regions.go +++ b/builder/amazon/common/regions.go @@ -2,6 +2,7 @@ package common import ( "fmt" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" ) @@ -20,17 +21,33 @@ func listEC2Regions(ec2conn ec2iface.EC2API) ([]string, error) { // ValidateRegion returns true if the supplied region is a valid AWS // region and false if it's not. -func ValidateRegion(region string, ec2conn ec2iface.EC2API) error { - regions, err := listEC2Regions(ec2conn) +func (c *AccessConfig) ValidateRegion(regions ...string) error { + ec2conn, err := c.NewEC2Connection() if err != nil { return err } - for _, valid := range regions { - if region == valid { - return nil + validRegions, err := listEC2Regions(ec2conn) + if err != nil { + return err + } + + var invalidRegions []string + for _, region := range regions { + found := false + for _, validRegion := range validRegions { + if region == validRegion { + found = true + break + } + } + if !found { + invalidRegions = append(invalidRegions, region) } } - return fmt.Errorf("Invalid region: %s", region) + if len(invalidRegions) > 0 { + return fmt.Errorf("Invalid region(s): %v", invalidRegions) + } + return nil }