diff --git a/builder/azure/arm/builder.go b/builder/azure/arm/builder.go index 48c19263e..f6f3a1f33 100644 --- a/builder/azure/arm/builder.go +++ b/builder/azure/arm/builder.go @@ -59,8 +59,10 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe b.ctxCancel = cancel defer cancel() - if err := newConfigRetriever().FillParameters(b.config); err != nil { - return nil, err + if !b.config.useMSI() { + if err := newConfigRetriever().FillParameters(b.config); err != nil { + return nil, err + } } log.Print(":: Configuration") @@ -74,6 +76,12 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe return nil, err } + if b.config.useMSI() { + if err := newConfigRetriever().FillParameters(b.config); err != nil { + return nil, err + } + } + ui.Message("Creating Azure Resource Manager (ARM) client ...") azureClient, err := NewAzureClient( b.config.SubscriptionID, diff --git a/builder/azure/arm/config_retriever.go b/builder/azure/arm/config_retriever.go index 3e8a3e655..4760041d8 100644 --- a/builder/azure/arm/config_retriever.go +++ b/builder/azure/arm/config_retriever.go @@ -8,6 +8,10 @@ package arm // 1. TenantID import ( + "encoding/json" + "io/ioutil" + "net/http" + "github.com/Azure/go-autorest/autorest/azure" "github.com/hashicorp/packer/builder/azure/common" ) @@ -24,6 +28,14 @@ func newConfigRetriever() configRetriever { } func (cr configRetriever) FillParameters(c *Config) error { + if c.SubscriptionID == "" { + subscriptionID, err := cr.getSubscriptionFromIMDS() + if err != nil { + return err + } + c.SubscriptionID = subscriptionID + } + if c.TenantID == "" { tenantID, err := cr.findTenantID(*c.cloudEnvironment, c.SubscriptionID) if err != nil { @@ -34,3 +46,30 @@ func (cr configRetriever) FillParameters(c *Config) error { return nil } + +func (cr configRetriever) getSubscriptionFromIMDS() (string, error) { + client := &http.Client{} + + req, _ := http.NewRequest("GET", "http://169.254.169.254/metadata/instance/compute", nil) + req.Header.Add("Metadata", "True") + + q := req.URL.Query() + q.Add("format", "json") + q.Add("api-version", "2017-08-01") + + req.URL.RawQuery = q.Encode() + resp, err := client.Do(req) + if err != nil { + return "", err + } + + defer resp.Body.Close() + resp_body, _ := ioutil.ReadAll(resp.Body) + result := map[string]string{} + err = json.Unmarshal(resp_body, &result) + if err != nil { + return "", err + } + + return result["subscriptionId"], nil +}