diff --git a/builder/azure/arm/builder.go b/builder/azure/arm/builder.go index 27bc32a4a..2381e529f 100644 --- a/builder/azure/arm/builder.go +++ b/builder/azure/arm/builder.go @@ -79,7 +79,7 @@ func (b *Builder) Run(ctx context.Context, ui packer.Ui, hook packer.Hook) (pack b.config.ClientConfig.SubscriptionID, b.config.ResourceGroupName, b.config.StorageAccount, - b.config.ClientConfig.CloudEnvironment, + b.config.ClientConfig.CloudEnvironment(), b.config.SharedGalleryTimeout, spnCloud, spnKeyVault) diff --git a/builder/azure/arm/config_test.go b/builder/azure/arm/config_test.go index 9ed25fffa..b4643ffdc 100644 --- a/builder/azure/arm/config_test.go +++ b/builder/azure/arm/config_test.go @@ -277,8 +277,8 @@ func TestConfigShouldDefaultToPublicCloud(t *testing.T) { t.Errorf("Expected 'CloudEnvironmentName' to default to 'Public', but got '%s'.", c.ClientConfig.CloudEnvironmentName) } - if c.ClientConfig.CloudEnvironment == nil || c.ClientConfig.CloudEnvironment.Name != "AzurePublicCloud" { - t.Errorf("Expected 'cloudEnvironment' to be set to 'AzurePublicCloud', but got '%s'.", c.ClientConfig.CloudEnvironment) + if c.ClientConfig.CloudEnvironment() == nil || c.ClientConfig.CloudEnvironment().Name != "AzurePublicCloud" { + t.Errorf("Expected 'cloudEnvironment' to be set to 'AzurePublicCloud', but got '%s'.", c.ClientConfig.CloudEnvironment()) } } @@ -327,8 +327,8 @@ func TestConfigInstantiatesCorrectAzureEnvironment(t *testing.T) { t.Fatal(err) } - if c.ClientConfig.CloudEnvironment == nil || c.ClientConfig.CloudEnvironment.Name != x.environmentName { - t.Errorf("Expected 'cloudEnvironment' to be set to '%s', but got '%s'.", x.environmentName, c.ClientConfig.CloudEnvironment) + if c.ClientConfig.CloudEnvironment() == nil || c.ClientConfig.CloudEnvironment().Name != x.environmentName { + t.Errorf("Expected 'cloudEnvironment' to be set to '%s', but got '%s'.", x.environmentName, c.ClientConfig.CloudEnvironment()) } } } diff --git a/builder/azure/common/client/config.go b/builder/azure/common/client/config.go index 9a4a736e8..601e7ac25 100644 --- a/builder/azure/common/client/config.go +++ b/builder/azure/common/client/config.go @@ -4,11 +4,12 @@ package client import ( "fmt" - "github.com/hashicorp/packer/builder/azure/common" "os" "strings" "time" + "github.com/hashicorp/packer/builder/azure/common" + "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure" jwt "github.com/dgrijalva/jwt-go" @@ -28,7 +29,7 @@ type Config struct { // USGovernment. Defaults to Public. Long forms such as // USGovernmentCloud and AzureUSGovernmentCloud are also supported. CloudEnvironmentName string `mapstructure:"cloud_environment_name" required:"false"` - CloudEnvironment *azure.Environment + cloudEnvironment *azure.Environment // Authentication fields @@ -73,6 +74,10 @@ func (c *Config) SetDefaultValues() error { return c.setCloudEnvironment() } +func (c *Config) CloudEnvironment() *azure.Environment { + return c.cloudEnvironment +} + func (c *Config) setCloudEnvironment() error { lookup := map[string]string{ "CHINA": "AzureChinaCloud", @@ -103,7 +108,7 @@ func (c *Config) setCloudEnvironment() error { } env, err := azure.EnvironmentFromName(envName) - c.CloudEnvironment = &env + c.cloudEnvironment = &env return err } @@ -210,22 +215,22 @@ func (c Config) GetServicePrincipalTokens( switch c.authType { case authTypeDeviceLogin: say("Getting tokens using device flow") - auth = NewDeviceFlowOAuthTokenProvider(*c.CloudEnvironment, say, tenantID) + auth = NewDeviceFlowOAuthTokenProvider(*c.cloudEnvironment, say, tenantID) case authTypeMSI: say("Getting tokens using Managed Identity for Azure") - auth = NewMSIOAuthTokenProvider(*c.CloudEnvironment) + auth = NewMSIOAuthTokenProvider(*c.cloudEnvironment) case authTypeClientSecret: say("Getting tokens using client secret") - auth = NewSecretOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientSecret, tenantID) + auth = NewSecretOAuthTokenProvider(*c.cloudEnvironment, c.ClientID, c.ClientSecret, tenantID) case authTypeClientCert: say("Getting tokens using client certificate") - auth, err = NewCertOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientCertPath, tenantID) + auth, err = NewCertOAuthTokenProvider(*c.cloudEnvironment, c.ClientID, c.ClientCertPath, tenantID) if err != nil { return nil, nil, err } case authTypeClientBearerJWT: say("Getting tokens using client bearer JWT") - auth = NewJWTOAuthTokenProvider(*c.CloudEnvironment, c.ClientID, c.ClientJWT, tenantID) + auth = NewJWTOAuthTokenProvider(*c.cloudEnvironment, c.ClientID, c.ClientJWT, tenantID) default: panic("authType not set, call FillParameters, or set explicitly") } @@ -241,7 +246,7 @@ func (c Config) GetServicePrincipalTokens( } servicePrincipalTokenVault, err = auth.getServicePrincipalTokenWithResource( - strings.TrimRight(c.CloudEnvironment.KeyVaultEndpoint, "/")) + strings.TrimRight(c.cloudEnvironment.KeyVaultEndpoint, "/")) if err != nil { return nil, nil, err } @@ -280,7 +285,7 @@ func (c *Config) FillParameters() error { c.SubscriptionID = subscriptionID } - if c.CloudEnvironment == nil { + if c.cloudEnvironment == nil { err := c.setCloudEnvironment() if err != nil { return err @@ -288,7 +293,7 @@ func (c *Config) FillParameters() error { } if c.TenantID == "" { - tenantID, err := findTenantID(*c.CloudEnvironment, c.SubscriptionID) + tenantID, err := findTenantID(*c.cloudEnvironment, c.SubscriptionID) if err != nil { return err } diff --git a/builder/azure/common/client/config_test.go b/builder/azure/common/client/config_test.go index 3e76b4d2b..a914e25c3 100644 --- a/builder/azure/common/client/config_test.go +++ b/builder/azure/common/client/config_test.go @@ -133,7 +133,7 @@ func Test_ClientConfig_DeviceLogin(t *testing.T) { getEnvOrSkip(t, "AZURE_DEVICE_LOGIN") cfg := Config{ SubscriptionID: getEnvOrSkip(t, "AZURE_SUBSCRIPTION"), - CloudEnvironment: getCloud(), + cloudEnvironment: getCloud(), } assertValid(t, cfg) @@ -164,7 +164,7 @@ func Test_ClientConfig_ClientPassword(t *testing.T) { ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"), ClientSecret: getEnvOrSkip(t, "AZURE_CLIENTSECRET"), TenantID: getEnvOrSkip(t, "AZURE_TENANTID"), - CloudEnvironment: getCloud(), + cloudEnvironment: getCloud(), } assertValid(t, cfg) @@ -194,7 +194,7 @@ func Test_ClientConfig_ClientCert(t *testing.T) { ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"), ClientCertPath: getEnvOrSkip(t, "AZURE_CLIENTCERT"), TenantID: getEnvOrSkip(t, "AZURE_TENANTID"), - CloudEnvironment: getCloud(), + cloudEnvironment: getCloud(), } assertValid(t, cfg) @@ -224,7 +224,7 @@ func Test_ClientConfig_ClientJWT(t *testing.T) { ClientID: getEnvOrSkip(t, "AZURE_CLIENTID"), ClientJWT: getEnvOrSkip(t, "AZURE_CLIENTJWT"), TenantID: getEnvOrSkip(t, "AZURE_TENANTID"), - CloudEnvironment: getCloud(), + cloudEnvironment: getCloud(), } assertValid(t, cfg)