diff --git a/builder/azure/arm/README.md b/builder/azure/arm/README.md index 32440f6d0..9f1c3aec8 100644 --- a/builder/azure/arm/README.md +++ b/builder/azure/arm/README.md @@ -6,21 +6,91 @@ Please see the [overview](https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/) for more information about ARM as well as the benefit of ARM. -## Getting Started +## Device Login vs. Service Principal Name (SPN) -The ARM APIs use OAUTH to authenticate, so you must create a Service -Principal. The following articles are a good starting points. +There are two ways to get started with packer-azure. The simplest is device login, and only requires a Subscription ID. +Device login is only supported for Linux based VMs. The second is the use of an SPN. We recommend the device login +approach for those who are first time users, and just want to ''kick the tires.'' We recommend the SPN approach if you +intend to automate Packer, or you are deploying Windows VMs. + +## Device Login + +A sample template for device login is show below. There are three pieces of information +you must provide to enable device login mode. + + 1. SubscriptionID + 1. Resource Group - parent resource group that Packer uses to build an image. + 1. Storage Account - storage account where the image will be placed. + +> Device login mode is enabled by not setting client_id, client_secret, and tenant_id. + +The device login flow asks that you open a web browser, navigate to http://aka.ms/devicelogin, and input the supplied +code. This authorizes the Packer for Azure application to act on your behalf. An OAuth token will be created, and +stored in the user's home directory (~/.azure/packer/oauth-TenantID.json, and TenantID will be replaced with the actual +Tenant ID). This token is used if it exists, and refreshed as necessary. + +```json +{ + "variables": { + "sid": "your_subscription_id", + "rgn": "your_resource_group", + "sa": "your_storage_account" + }, + "builders": [ + { + "type": "azure-arm", + + "subscription_id": "{{user `sid`}}", + + "resource_group_name": "{{user `rgn`}}", + "storage_account": "{{user `sa`}}", + + "capture_container_name": "images", + "capture_name_prefix": "packer", + + "os_type": "Linux", + "image_publisher": "Canonical", + "image_offer": "UbuntuServer", + "image_sku": "14.04.3-LTS", + + "location": "South Central US", + "vm_size": "Standard_A2" + } + ], + "provisioners": [ + { + "execute_command": "chmod +x {{ .Path }}; {{ .Vars }} sudo -E sh '{{ .Path }}'", + "inline": [ + "apt-get update", + "apt-get upgrade -y", + + "/usr/sbin/waagent -force -deprovision+user && export HISTSIZE=0 && sync" + ], + "inline_shebang": "/bin/sh -x", + "type": "shell" + } + ] +} +``` + +## Service Principal Name + +The ARM APIs use OAUTH to authenticate, and requires an SPN. The following articles +are a good starting points for creating a new SPN. * [Automating Azure on your CI server using a Service Principal](http://blog.davidebbo.com/2014/12/azure-service-principal.html) * [Authenticating a service principal with Azure Resource Manager](https://azure.microsoft.com/en-us/documentation/articles/resource-group-authenticate-service-principal/) -There are three pieces of configuration you will need as a result of -creating a Service Principal. +There are three (four in the case of Windows) pieces of configuration you need to note +after creating an SPN. 1. Client ID (aka Service Principal ID) 1. Client Secret (aka Service Principal generated key) 1. Client Tenant (aka Azure Active Directory tenant that owns the Service Principal) + 1. Object ID (Windows only) - a certificate is used to authenticate WinRM access, and the certificate is injected into + the VM using Azure Key Vault. Access to the key vault is protected by an ACL associated with the SPN's ObjectID. + Linux does not need nor use a key vault, so there's no need to know the ObjectID. You will also need the following. @@ -38,7 +108,8 @@ for more details. * Owner > NOTE: the Owner role is too powerful, and more explicit set of roles -> is TBD. Issue #183 is tracking this work. +> is TBD. Issue #183 is tracking this work. Permissions can be scoped to +> a specific resource group to further limit access. ### Sample Ubuntu @@ -65,26 +136,30 @@ Azure for ARM builder. "subscription_id": "{{user `sid`}}", "tenant_id": "{{user `tid`}}", - "capture_container_name": "images", - "capture_name_prefix": "my_prefix", + "resource_group_name": "{{user `rgn`}}", + "storage_account": "{{user `sa`}}", + "capture_container_name": "images", + "capture_name_prefix": "packer", + + "os_type": "Linux", "image_publisher": "Canonical", "image_offer": "UbuntuServer", "image_sku": "14.04.3-LTS", "location": "South Central US", - "resource_group_name": "{{user `rgn`}}", - "storage_account": "{{user `sa`}}", - - "vm_size": "Standard_A1" + "vm_size": "Standard_A2" } ], "provisioners": [ { "execute_command": "chmod +x {{ .Path }}; {{ .Vars }} sudo -E sh '{{ .Path }}'", "inline": [ - "sudo apt-get update", + "apt-get update", + "apt-get upgrade -y", + + "/usr/sbin/waagent -force -deprovision+user && export HISTSIZE=0 && sync" ], "inline_shebang": "/bin/sh -x", "type": "shell" diff --git a/builder/azure/arm/artifact.go b/builder/azure/arm/artifact.go index 9dbe1f21a..23d372497 100644 --- a/builder/azure/arm/artifact.go +++ b/builder/azure/arm/artifact.go @@ -1,36 +1,112 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm +import ( + "bytes" + "fmt" + "net/url" + "path" + "strings" +) + const ( BuilderId = "Azure.ResourceManagement.VMImage" ) -type artifact struct { - name string +type Artifact struct { + StorageAccountLocation string + OSDiskUri string + TemplateUri string + OSDiskUriReadOnlySas string + TemplateUriReadOnlySas string } -func (*artifact) BuilderId() string { +func NewArtifact(template *CaptureTemplate, getSasUrl func(name string) string) (*Artifact, error) { + if template == nil { + return nil, fmt.Errorf("nil capture template") + } + + if len(template.Resources) != 1 { + return nil, fmt.Errorf("malformed capture template, expected one resource") + } + + vhdUri, err := url.Parse(template.Resources[0].Properties.StorageProfile.OSDisk.Image.Uri) + if err != nil { + return nil, err + } + + templateUri, err := storageUriToTemplateUri(vhdUri) + if err != nil { + return nil, err + } + + return &Artifact{ + OSDiskUri: vhdUri.String(), + OSDiskUriReadOnlySas: getSasUrl(getStorageUrlPath(vhdUri)), + TemplateUri: templateUri.String(), + TemplateUriReadOnlySas: getSasUrl(getStorageUrlPath(templateUri)), + + StorageAccountLocation: template.Resources[0].Location, + }, nil +} + +func getStorageUrlPath(u *url.URL) string { + parts := strings.Split(u.Path, "/") + return strings.Join(parts[3:], "/") +} + +func storageUriToTemplateUri(su *url.URL) (*url.URL, error) { + // packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd -> 4085bb15-3644-4641-b9cd-f575918640b4 + filename := path.Base(su.Path) + parts := strings.Split(filename, ".") + + if len(parts) < 3 { + return nil, fmt.Errorf("malformed URL") + } + + // packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd -> packer + prefixParts := strings.Split(parts[0], "-") + prefix := strings.Join(prefixParts[:len(prefixParts)-1], "-") + + templateFilename := fmt.Sprintf("%s-vmTemplate.%s.json", prefix, parts[1]) + + // https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd" + // -> + // https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" + return url.Parse(strings.Replace(su.String(), filename, templateFilename, 1)) +} + +func (*Artifact) BuilderId() string { return BuilderId } -func (*artifact) Files() []string { +func (*Artifact) Files() []string { return []string{} } -func (*artifact) Id() string { +func (*Artifact) Id() string { return "" } -func (*artifact) State(name string) interface{} { +func (*Artifact) State(name string) interface{} { return nil } -func (*artifact) String() string { - return "{}" +func (a *Artifact) String() string { + var buf bytes.Buffer + + buf.WriteString(fmt.Sprintf("%s:\n\n", a.BuilderId())) + buf.WriteString(fmt.Sprintf("StorageAccountLocation: %s\n", a.StorageAccountLocation)) + buf.WriteString(fmt.Sprintf("OSDiskUri: %s\n", a.OSDiskUri)) + buf.WriteString(fmt.Sprintf("OSDiskUriReadOnlySas: %s\n", a.OSDiskUriReadOnlySas)) + buf.WriteString(fmt.Sprintf("TemplateUri: %s\n", a.TemplateUri)) + buf.WriteString(fmt.Sprintf("TemplateUriReadOnlySas: %s\n", a.TemplateUriReadOnlySas)) + + return buf.String() } -func (*artifact) Destroy() error { +func (*Artifact) Destroy() error { return nil } diff --git a/builder/azure/arm/artifact_test.go b/builder/azure/arm/artifact_test.go new file mode 100644 index 000000000..a43533bd0 --- /dev/null +++ b/builder/azure/arm/artifact_test.go @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "fmt" + "strings" + "testing" +) + +func getFakeSasUrl(name string) string { + return fmt.Sprintf("SAS-%s", name) +} + +func TestArtifactString(t *testing.T) { + template := CaptureTemplate{ + Resources: []CaptureResources{ + CaptureResources{ + Properties: CaptureProperties{ + StorageProfile: CaptureStorageProfile{ + OSDisk: CaptureDisk{ + Image: CaptureUri{ + Uri: "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd", + }, + }, + }, + }, + Location: "southcentralus", + }, + }, + } + + artifact, err := NewArtifact(&template, getFakeSasUrl) + if err != nil { + t.Fatalf("err=%s", err) + } + + testSubject := artifact.String() + if !strings.Contains(testSubject, "OSDiskUri: https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd") { + t.Errorf("Expected String() output to contain OSDiskUri") + } + if !strings.Contains(testSubject, "OSDiskUriReadOnlySas: SAS-Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd") { + t.Errorf("Expected String() output to contain OSDiskUriReadOnlySas") + } + if !strings.Contains(testSubject, "TemplateUri: https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json") { + t.Errorf("Expected String() output to contain TemplateUri") + } + if !strings.Contains(testSubject, "TemplateUriReadOnlySas: SAS-Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json") { + t.Errorf("Expected String() output to contain TemplateUriReadOnlySas") + } + if !strings.Contains(testSubject, "StorageAccountLocation: southcentralus") { + t.Errorf("Expected String() output to contain StorageAccountLocation") + } +} + +func TestArtifactProperties(t *testing.T) { + template := CaptureTemplate{ + Resources: []CaptureResources{ + CaptureResources{ + Properties: CaptureProperties{ + StorageProfile: CaptureStorageProfile{ + OSDisk: CaptureDisk{ + Image: CaptureUri{ + Uri: "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd", + }, + }, + }, + }, + Location: "southcentralus", + }, + }, + } + + testSubject, err := NewArtifact(&template, getFakeSasUrl) + if err != nil { + t.Fatalf("err=%s", err) + } + + if testSubject.OSDiskUri != "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd" { + t.Errorf("Expected template to be 'https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd', but got %s", testSubject.OSDiskUri) + } + if testSubject.OSDiskUriReadOnlySas != "SAS-Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd" { + t.Errorf("Expected template to be 'SAS-Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd', but got %s", testSubject.OSDiskUriReadOnlySas) + } + if testSubject.TemplateUri != "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" { + t.Errorf("Expected template to be 'https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json', but got %s", testSubject.TemplateUri) + } + if testSubject.TemplateUriReadOnlySas != "SAS-Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" { + t.Errorf("Expected template to be 'SAS-Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json', but got %s", testSubject.TemplateUriReadOnlySas) + } + if testSubject.StorageAccountLocation != "southcentralus" { + t.Errorf("Expected StorageAccountLocation to be 'southcentral', but got %s", testSubject.StorageAccountLocation) + } +} + +func TestArtifactOverHypenatedCaptureUri(t *testing.T) { + template := CaptureTemplate{ + Resources: []CaptureResources{ + CaptureResources{ + Properties: CaptureProperties{ + StorageProfile: CaptureStorageProfile{ + OSDisk: CaptureDisk{ + Image: CaptureUri{ + Uri: "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/pac-ker-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd", + }, + }, + }, + }, + Location: "southcentralus", + }, + }, + } + + testSubject, err := NewArtifact(&template, getFakeSasUrl) + if err != nil { + t.Fatalf("err=%s", err) + } + + if testSubject.TemplateUri != "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/pac-ker-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" { + t.Errorf("Expected template to be 'https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/pac-ker-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json', but got %s", testSubject.TemplateUri) + } +} + +func TestArtifactRejectMalformedTemplates(t *testing.T) { + template := CaptureTemplate{} + + _, err := NewArtifact(&template, getFakeSasUrl) + if err == nil { + t.Fatalf("Expected artifact creation to fail, but it succeeded.") + } +} + +func TestArtifactRejectMalformedStorageUri(t *testing.T) { + template := CaptureTemplate{ + Resources: []CaptureResources{ + CaptureResources{ + Properties: CaptureProperties{ + StorageProfile: CaptureStorageProfile{ + OSDisk: CaptureDisk{ + Image: CaptureUri{ + Uri: "bark", + }, + }, + }, + }, + }, + }, + } + + _, err := NewArtifact(&template, getFakeSasUrl) + if err == nil { + t.Fatalf("Expected artifact creation to fail, but it succeeded.") + } +} diff --git a/builder/azure/arm/authenticate.go b/builder/azure/arm/authenticate.go new file mode 100644 index 000000000..ba4d05754 --- /dev/null +++ b/builder/azure/arm/authenticate.go @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import "github.com/Azure/go-autorest/autorest/azure" + +type Authenticate struct { + env azure.Environment + clientID string + clientSecret string + tenantID string +} + +func NewAuthenticate(env azure.Environment, clientID, clientSecret, tenantID string) *Authenticate { + return &Authenticate{ + env: env, + clientID: clientID, + clientSecret: clientSecret, + tenantID: tenantID, + } +} + +func (a *Authenticate) getServicePrincipalToken() (*azure.ServicePrincipalToken, error) { + return a.getServicePrincipalTokenWithResource(a.env.ResourceManagerEndpoint) +} + +func (a *Authenticate) getServicePrincipalTokenWithResource(resource string) (*azure.ServicePrincipalToken, error) { + oauthConfig, err := newOAuthConfigWithTenant(a.tenantID) + if err != nil { + return nil, err + } + + spt, err := azure.NewServicePrincipalToken( + *oauthConfig, + a.clientID, + a.clientSecret, + resource) + + return spt, err +} + +func newOAuthConfigWithTenant(tenantID string) (*azure.OAuthConfig, error) { + return azure.PublicCloud.OAuthConfigForTenant(tenantID) +} diff --git a/builder/azure/arm/authenticate_test.go b/builder/azure/arm/authenticate_test.go new file mode 100644 index 000000000..2e290bb3d --- /dev/null +++ b/builder/azure/arm/authenticate_test.go @@ -0,0 +1,39 @@ +package arm + +import ( + "github.com/Azure/go-autorest/autorest/azure" + "testing" +) + +// Behavior is the most important thing to assert for ServicePrincipalToken, but +// that cannot be done in a unit test because it involves network access. Instead, +// I assert the expected intertness of this class. +func TestNewAuthenticate(t *testing.T) { + testSubject := NewAuthenticate(azure.PublicCloud, "clientID", "clientString", "tenantID") + spn, err := testSubject.getServicePrincipalToken() + if err != nil { + t.Fatalf(err.Error()) + } + + if spn.Token.AccessToken != "" { + t.Errorf("spn.Token.AccessToken: expected=\"\", actual=%s", spn.Token.AccessToken) + } + if spn.Token.RefreshToken != "" { + t.Errorf("spn.Token.RefreshToken: expected=\"\", actual=%s", spn.Token.RefreshToken) + } + if spn.Token.ExpiresIn != "" { + t.Errorf("spn.Token.ExpiresIn: expected=\"\", actual=%s", spn.Token.ExpiresIn) + } + if spn.Token.ExpiresOn != "" { + t.Errorf("spn.Token.ExpiresOn: expected=\"\", actual=%s", spn.Token.ExpiresOn) + } + if spn.Token.NotBefore != "" { + t.Errorf("spn.Token.NotBefore: expected=\"\", actual=%s", spn.Token.NotBefore) + } + if spn.Token.Resource != "" { + t.Errorf("spn.Token.Resource: expected=\"\", actual=%s", spn.Token.Resource) + } + if spn.Token.Type != "" { + t.Errorf("spn.Token.Type: expected=\"\", actual=%s", spn.Token.Type) + } +} diff --git a/builder/azure/arm/azure_client.go b/builder/azure/arm/azure_client.go index 797b89df5..59df7f30f 100644 --- a/builder/azure/arm/azure_client.go +++ b/builder/azure/arm/azure_client.go @@ -1,45 +1,150 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( + "encoding/json" + "fmt" + "math" + "net/http" + "os" + "strconv" + "github.com/Azure/azure-sdk-for-go/arm/compute" "github.com/Azure/azure-sdk-for-go/arm/network" "github.com/Azure/azure-sdk-for-go/arm/resources/resources" armStorage "github.com/Azure/azure-sdk-for-go/arm/storage" "github.com/Azure/azure-sdk-for-go/storage" - + "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure" + "github.com/mitchellh/packer/builder/azure/common" + "github.com/mitchellh/packer/version" +) + +const ( + EnvPackerLogAzureMaxLen = "PACKER_LOG_AZURE_MAXLEN" +) + +var ( + packerUserAgent = fmt.Sprintf(";packer/%s", version.FormattedVersion()) ) type AzureClient struct { - compute.VirtualMachinesClient - network.PublicIPAddressesClient - resources.GroupsClient - resources.DeploymentsClient storage.BlobStorageClient + resources.DeploymentsClient + resources.GroupsClient + network.PublicIPAddressesClient + compute.VirtualMachinesClient + common.VaultClient + armStorage.AccountsClient + + InspectorMaxLength int + Template *CaptureTemplate } -func NewAzureClient(subscriptionID string, resourceGroupName string, storageAccountName string, servicePrincipalToken *azure.ServicePrincipalToken) (*AzureClient, error) { +func getCaptureResponse(body string) *CaptureTemplate { + var operation CaptureOperation + err := json.Unmarshal([]byte(body), &operation) + if err != nil { + return nil + } + + if operation.Properties != nil && operation.Properties.Output != nil { + return operation.Properties.Output + } + + return nil +} + +// HACK(chrboum): This method is a hack. It was written to work around this issue +// (https://github.com/Azure/azure-sdk-for-go/issues/307) and to an extent this +// issue (https://github.com/Azure/azure-rest-api-specs/issues/188). +// +// Capturing a VM is a long running operation that requires polling. There are +// couple different forms of polling, and the end result of a poll operation is +// discarded by the SDK. It is expected that any discarded data can be re-fetched, +// so discarding it has minimal impact. Unfortunately, there is no way to re-fetch +// the template returned by a capture call that I am aware of. +// +// If the second issue were fixed the VM ID would be included when GET'ing a VM. The +// VM ID could be used to locate the captured VHD, and captured template. +// Unfortunately, the VM ID is not included so this method cannot be used either. +// +// This code captures the template and saves it to the client (the AzureClient type). +// It expects that the capture API is called only once, or rather you only care that the +// last call's value is important because subsequent requests are not persisted. There +// is no care given to multiple threads writing this value because for our use case +// it does not matter. +func templateCapture(client *AzureClient) autorest.RespondDecorator { + return func(r autorest.Responder) autorest.Responder { + return autorest.ResponderFunc(func(resp *http.Response) error { + body, bodyString := handleBody(resp.Body, math.MaxInt64) + resp.Body = body + + captureTemplate := getCaptureResponse(bodyString) + if captureTemplate != nil { + client.Template = captureTemplate + } + + return r.Respond(resp) + }) + } +} + +// WAITING(chrboum): I have logged https://github.com/Azure/azure-sdk-for-go/issues/311 to get this +// method included in the SDK. It has been accepted, and I'll cut over to the official way +// once it ships. +func byConcatDecorators(decorators ...autorest.RespondDecorator) autorest.RespondDecorator { + return func(r autorest.Responder) autorest.Responder { + return autorest.DecorateResponder(r, decorators...) + } +} + +func NewAzureClient(subscriptionID, resourceGroupName, storageAccountName string, + servicePrincipalToken, servicePrincipalTokenVault *azure.ServicePrincipalToken) (*AzureClient, error) { + var azureClient = &AzureClient{} + maxlen := getInspectorMaxLength() + azureClient.DeploymentsClient = resources.NewDeploymentsClient(subscriptionID) azureClient.DeploymentsClient.Authorizer = servicePrincipalToken + azureClient.DeploymentsClient.RequestInspector = withInspection(maxlen) + azureClient.DeploymentsClient.ResponseInspector = byInspecting(maxlen) + azureClient.DeploymentsClient.UserAgent += packerUserAgent azureClient.GroupsClient = resources.NewGroupsClient(subscriptionID) azureClient.GroupsClient.Authorizer = servicePrincipalToken + azureClient.GroupsClient.RequestInspector = withInspection(maxlen) + azureClient.GroupsClient.ResponseInspector = byInspecting(maxlen) + azureClient.GroupsClient.UserAgent += packerUserAgent azureClient.PublicIPAddressesClient = network.NewPublicIPAddressesClient(subscriptionID) azureClient.PublicIPAddressesClient.Authorizer = servicePrincipalToken + azureClient.PublicIPAddressesClient.RequestInspector = withInspection(maxlen) + azureClient.PublicIPAddressesClient.ResponseInspector = byInspecting(maxlen) + azureClient.PublicIPAddressesClient.UserAgent += packerUserAgent azureClient.VirtualMachinesClient = compute.NewVirtualMachinesClient(subscriptionID) azureClient.VirtualMachinesClient.Authorizer = servicePrincipalToken + azureClient.VirtualMachinesClient.RequestInspector = withInspection(maxlen) + azureClient.VirtualMachinesClient.ResponseInspector = byConcatDecorators(byInspecting(maxlen), templateCapture(azureClient)) + azureClient.VirtualMachinesClient.UserAgent += packerUserAgent - storageAccountsClient := armStorage.NewAccountsClient(subscriptionID) - storageAccountsClient.Authorizer = servicePrincipalToken + azureClient.AccountsClient = armStorage.NewAccountsClient(subscriptionID) + azureClient.AccountsClient.Authorizer = servicePrincipalToken + azureClient.AccountsClient.RequestInspector = withInspection(maxlen) + azureClient.AccountsClient.ResponseInspector = byInspecting(maxlen) + azureClient.AccountsClient.UserAgent += packerUserAgent - accountKeys, err := storageAccountsClient.ListKeys(resourceGroupName, storageAccountName) + azureClient.VaultClient = common.VaultClient{} + azureClient.VaultClient.Authorizer = servicePrincipalTokenVault + azureClient.VaultClient.RequestInspector = withInspection(maxlen) + azureClient.VaultClient.ResponseInspector = byInspecting(maxlen) + azureClient.VaultClient.UserAgent += packerUserAgent + + accountKeys, err := azureClient.AccountsClient.ListKeys(resourceGroupName, storageAccountName) if err != nil { return nil, err } @@ -52,3 +157,21 @@ func NewAzureClient(subscriptionID string, resourceGroupName string, storageAcco azureClient.BlobStorageClient = storageClient.GetBlobService() return azureClient, nil } + +func getInspectorMaxLength() int64 { + value, ok := os.LookupEnv(EnvPackerLogAzureMaxLen) + if !ok { + return math.MaxInt64 + } + + i, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return 0 + } + + if i < 0 { + return 0 + } + + return i +} diff --git a/builder/azure/arm/builder.go b/builder/azure/arm/builder.go index cc9e7463a..71dc8f822 100644 --- a/builder/azure/arm/builder.go +++ b/builder/azure/arm/builder.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,12 +7,15 @@ import ( "errors" "fmt" "log" + "time" + + packerAzureCommon "github.com/mitchellh/packer/builder/azure/common" + + "github.com/Azure/go-autorest/autorest/azure" "github.com/mitchellh/packer/builder/azure/common/constants" "github.com/mitchellh/packer/builder/azure/common/lin" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/common" "github.com/mitchellh/packer/helper/communicator" @@ -27,6 +30,9 @@ type Builder struct { const ( DefaultPublicIPAddressName = "packerPublicIP" + DefaultSasBlobContainer = "system/Microsoft.Compute" + DefaultSasBlobPermission = "r" + DefaultSecretName = "packerKeyVaultSecret" ) func (b *Builder) Prepare(raws ...interface{}) ([]string, error) { @@ -38,10 +44,8 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) { b.config = c b.stateBag = new(multistep.BasicStateBag) - err := b.configureStateBag(b.stateBag) - if err != nil { - return nil, err - } + b.configureStateBag(b.stateBag) + b.setTemplateParameters(b.stateBag) return warnings, errs } @@ -52,33 +56,81 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe b.stateBag.Put("hook", hook) b.stateBag.Put(constants.Ui, ui) - servicePrincipalToken, err := b.createServicePrincipalToken() + spnCloud, spnKeyVault, err := b.getServicePrincipalTokens(ui.Say) if err != nil { return nil, err } ui.Message("Creating Azure Resource Manager (ARM) client ...") - azureClient, err := NewAzureClient(b.config.SubscriptionID, b.config.ResourceGroupName, b.config.StorageAccount, servicePrincipalToken) + azureClient, err := NewAzureClient( + b.config.SubscriptionID, + b.config.ResourceGroupName, + b.config.StorageAccount, + spnCloud, + spnKeyVault) + if err != nil { return nil, err } - steps := []multistep.Step{ - NewStepCreateResourceGroup(azureClient, ui), - NewStepValidateTemplate(azureClient, ui), - NewStepDeployTemplate(azureClient, ui), - NewStepGetIPAddress(azureClient, ui), - &communicator.StepConnectSSH{ - Config: &b.config.Comm, - Host: lin.SSHHost, - SSHConfig: lin.SSHConfig(b.config.UserName), - }, - &common.StepProvision{}, - NewStepGetOSDisk(azureClient, ui), - NewStepPowerOffCompute(azureClient, ui), - NewStepCaptureImage(azureClient, ui), - NewStepDeleteResourceGroup(azureClient, ui), - NewStepDeleteOSDisk(azureClient, ui), + b.config.storageAccountBlobEndpoint, err = b.getBlobEndpoint(azureClient, b.config.ResourceGroupName, b.config.StorageAccount) + if err != nil { + return nil, err + } + + b.setTemplateParameters(b.stateBag) + + var steps []multistep.Step + + if b.config.OSType == constants.Target_Linux { + steps = []multistep.Step{ + NewStepCreateResourceGroup(azureClient, ui), + NewStepValidateTemplate(azureClient, ui, Linux), + NewStepDeployTemplate(azureClient, ui, Linux), + NewStepGetIPAddress(azureClient, ui), + &communicator.StepConnectSSH{ + Config: &b.config.Comm, + Host: lin.SSHHost, + SSHConfig: lin.SSHConfig(b.config.UserName), + }, + &common.StepProvision{}, + NewStepGetOSDisk(azureClient, ui), + NewStepPowerOffCompute(azureClient, ui), + NewStepCaptureImage(azureClient, ui), + NewStepDeleteResourceGroup(azureClient, ui), + NewStepDeleteOSDisk(azureClient, ui), + } + } else if b.config.OSType == constants.Target_Windows { + steps = []multistep.Step{ + NewStepCreateResourceGroup(azureClient, ui), + NewStepValidateTemplate(azureClient, ui, KeyVault), + NewStepDeployTemplate(azureClient, ui, KeyVault), + NewStepGetCertificate(azureClient, ui), + NewStepSetCertificate(b.config, ui), + NewStepValidateTemplate(azureClient, ui, Windows), + NewStepDeployTemplate(azureClient, ui, Windows), + NewStepGetIPAddress(azureClient, ui), + &communicator.StepConnectWinRM{ + Config: &b.config.Comm, + Host: func(stateBag multistep.StateBag) (string, error) { + return stateBag.Get(constants.SSHHost).(string), nil + }, + WinRMConfig: func(multistep.StateBag) (*communicator.WinRMConfig, error) { + return &communicator.WinRMConfig{ + Username: b.config.UserName, + Password: b.config.tmpAdminPassword, + }, nil + }, + }, + &common.StepProvision{}, + NewStepGetOSDisk(azureClient, ui), + NewStepPowerOffCompute(azureClient, ui), + NewStepCaptureImage(azureClient, ui), + NewStepDeleteResourceGroup(azureClient, ui), + NewStepDeleteOSDisk(azureClient, ui), + } + } else { + return nil, fmt.Errorf("Builder does not support the os_type '%s'", b.config.OSType) } if b.config.PackerDebug { @@ -103,7 +155,17 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe return nil, errors.New("Build was halted.") } - return &artifact{}, nil + if template, ok := b.stateBag.GetOk(constants.ArmCaptureTemplate); ok { + return NewArtifact( + template.(*CaptureTemplate), + func(name string) string { + month := time.Now().AddDate(0, 1, 0).UTC() + sasUrl, _ := azureClient.BlobStorageClient.GetBlobSASURI(DefaultSasBlobContainer, name, month, DefaultSasBlobPermission) + return sasUrl + }) + } + + return &Artifact{}, nil } func (b *Builder) Cancel() { @@ -126,33 +188,57 @@ func (b *Builder) createRunner(steps *[]multistep.Step, ui packer.Ui) multistep. } } -func (b *Builder) configureStateBag(stateBag multistep.StateBag) error { +func (b *Builder) getBlobEndpoint(client *AzureClient, resourceGroupName string, storageAccountName string) (string, error) { + account, err := client.AccountsClient.GetProperties(resourceGroupName, storageAccountName) + if err != nil { + return "", err + } + + return *account.Properties.PrimaryEndpoints.Blob, nil +} + +func (b *Builder) configureStateBag(stateBag multistep.StateBag) { stateBag.Put(constants.AuthorizedKey, b.config.sshAuthorizedKey) stateBag.Put(constants.PrivateKey, b.config.sshPrivateKey) stateBag.Put(constants.ArmComputeName, b.config.tmpComputeName) stateBag.Put(constants.ArmDeploymentName, b.config.tmpDeploymentName) + stateBag.Put(constants.ArmKeyVaultName, b.config.tmpKeyVaultName) stateBag.Put(constants.ArmLocation, b.config.Location) + stateBag.Put(constants.ArmPublicIPAddressName, DefaultPublicIPAddressName) stateBag.Put(constants.ArmResourceGroupName, b.config.tmpResourceGroupName) + stateBag.Put(constants.ArmStorageAccountName, b.config.StorageAccount) +} + +func (b *Builder) setTemplateParameters(stateBag multistep.StateBag) { stateBag.Put(constants.ArmTemplateParameters, b.config.toTemplateParameters()) stateBag.Put(constants.ArmVirtualMachineCaptureParameters, b.config.toVirtualMachineCaptureParameters()) - - stateBag.Put(constants.ArmPublicIPAddressName, DefaultPublicIPAddressName) - - return nil } -func (b *Builder) createServicePrincipalToken() (*azure.ServicePrincipalToken, error) { - oauthConfig, err := azure.PublicCloud.OAuthConfigForTenant(b.config.TenantID) - if err != nil { - return nil, err +func (b *Builder) getServicePrincipalTokens(say func(string)) (*azure.ServicePrincipalToken, *azure.ServicePrincipalToken, error) { + var servicePrincipalToken *azure.ServicePrincipalToken + var servicePrincipalTokenVault *azure.ServicePrincipalToken + + var err error + + if b.config.useDeviceLogin { + servicePrincipalToken, err = packerAzureCommon.Authenticate(*b.config.cloudEnvironment, b.config.SubscriptionID, say) + if err != nil { + return nil, nil, err + } + } else { + auth := NewAuthenticate(*b.config.cloudEnvironment, b.config.ClientID, b.config.ClientSecret, b.config.TenantID) + + servicePrincipalToken, err = auth.getServicePrincipalToken() + if err != nil { + return nil, nil, err + } + + servicePrincipalTokenVault, err = auth.getServicePrincipalTokenWithResource(packerAzureCommon.AzureVaultScope) + if err != nil { + return nil, nil, err + } } - spt, err := azure.NewServicePrincipalToken( - *oauthConfig, - b.config.ClientID, - b.config.ClientSecret, - azure.PublicCloud.ResourceManagerEndpoint) - - return spt, err + return servicePrincipalToken, servicePrincipalTokenVault, nil } diff --git a/builder/azure/arm/builder_test.go b/builder/azure/arm/builder_test.go index 1485265df..bbf173f7d 100644 --- a/builder/azure/arm/builder_test.go +++ b/builder/azure/arm/builder_test.go @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( - "testing" - "github.com/mitchellh/packer/builder/azure/common/constants" + "testing" ) func TestStateBagShouldBePopulatedExpectedValues(t *testing.T) { var testSubject = &Builder{} - testSubject.Prepare(getArmBuilderConfiguration(), getPackerConfiguration()) + _, err := testSubject.Prepare(getArmBuilderConfiguration(), getPackerConfiguration()) + if err != nil { + t.Fatalf("failed to prepare: %s", err) + } var expectedStateBagKeys = []string{ constants.AuthorizedKey, @@ -21,6 +23,7 @@ func TestStateBagShouldBePopulatedExpectedValues(t *testing.T) { constants.ArmDeploymentName, constants.ArmLocation, constants.ArmResourceGroupName, + constants.ArmStorageAccountName, constants.ArmTemplateParameters, constants.ArmVirtualMachineCaptureParameters, constants.ArmPublicIPAddressName, diff --git a/builder/azure/arm/capture_template.go b/builder/azure/arm/capture_template.go new file mode 100644 index 000000000..9e136130d --- /dev/null +++ b/builder/azure/arm/capture_template.go @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +type CaptureTemplateParameter struct { + Type string `json:"type"` + DefaultValue string `json:"defaultValue,omitempty"` +} + +type CaptureHardwareProfile struct { + VMSize string `json:"vmSize"` +} + +type CaptureUri struct { + Uri string `json:"uri"` +} + +type CaptureDisk struct { + OSType string `json:"osType"` + Name string `json:"name"` + Image CaptureUri `json:"image"` + Vhd CaptureUri `json:"vhd"` + CreateOption string `json:"createOption"` + Caching string `json:"caching"` +} + +type CaptureStorageProfile struct { + OSDisk CaptureDisk `json:"osDisk"` +} + +type CaptureOSProfile struct { + ComputerName string `json:"computerName"` + AdminUsername string `json:"adminUsername"` + AdminPassword string `json:"adminPassword"` +} + +type CaptureNetworkInterface struct { + Id string `json:"id"` +} + +type CaptureNetworkProfile struct { + NetworkInterfaces []CaptureNetworkInterface `json:"networkInterfaces"` +} + +type CaptureBootDiagnostics struct { + Enabled bool `json:"enabled"` +} + +type CaptureDiagnosticProfile struct { + BootDiagnostics CaptureBootDiagnostics `json:"bootDiagnostics"` +} + +type CaptureProperties struct { + HardwareProfile CaptureHardwareProfile `json:"hardwareProfile"` + StorageProfile CaptureStorageProfile `json:"storageProfile"` + OSProfile CaptureOSProfile `json:"osProfile"` + NetworkProfile CaptureNetworkProfile `json:"networkProfile"` + DiagnosticsProfile CaptureDiagnosticProfile `json:"diagnosticsProfile"` + ProvisioningState int `json:"provisioningState"` +} + +type CaptureResources struct { + ApiVersion string `json:"apiVersion"` + Name string `json:"name"` + Type string `json:"type"` + Location string `json:"location"` + Properties CaptureProperties `json:"properties"` +} + +type CaptureTemplate struct { + Schema string `json:"$schema"` + ContentVersion string `json:"contentVersion"` + Parameters map[string]CaptureTemplateParameter `json:"parameters"` + Resources []CaptureResources `json:"resources"` +} + +type CaptureOperationProperties struct { + Output *CaptureTemplate `json:"output"` +} + +type CaptureOperation struct { + OperationId string `json:"operationId"` + Status string `json:"status"` + Properties *CaptureOperationProperties `json:"properties"` +} diff --git a/builder/azure/arm/capture_template_test.go b/builder/azure/arm/capture_template_test.go new file mode 100644 index 000000000..2145e0d1d --- /dev/null +++ b/builder/azure/arm/capture_template_test.go @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "encoding/json" + "testing" +) + +var captureTemplate01 = `{ + "operationId": "ac1c7c38-a591-41b3-89bd-ea39fceace1b", + "status": "Succeeded", + "startTime": "2016-04-04T21:07:25.2900874+00:00", + "endTime": "2016-04-04T21:07:26.4776321+00:00", + "properties": { + "output": { + "$schema": "http://schema.management.azure.com/schemas/2014-04-01-preview/VM_IP.json", + "contentVersion": "1.0.0.0", + "parameters": { + "vmName": { + "type": "string" + }, + "vmSize": { + "type": "string", + "defaultValue": "Standard_A2" + }, + "adminUserName": { + "type": "string" + }, + "adminPassword": { + "type": "securestring" + }, + "networkInterfaceId": { + "type": "string" + } + }, + "resources": [ + { + "apiVersion": "2015-06-15", + "properties": { + "hardwareProfile": { + "vmSize": "[parameters('vmSize')]" + }, + "storageProfile": { + "osDisk": { + "osType": "Linux", + "name": "packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd", + "createOption": "FromImage", + "image": { + "uri": "http://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd" + }, + "vhd": { + "uri": "http://storage.blob.core.windows.net/vmcontainerce1a1b75-f480-47cb-8e6e-55142e4a5f68/osDisk.ce1a1b75-f480-47cb-8e6e-55142e4a5f68.vhd" + }, + "caching": "ReadWrite" + } + }, + "osProfile": { + "computerName": "[parameters('vmName')]", + "adminUsername": "[parameters('adminUsername')]", + "adminPassword": "[parameters('adminPassword')]" + }, + "networkProfile": { + "networkInterfaces": [ + { + "id": "[parameters('networkInterfaceId')]" + } + ] + }, + "diagnosticsProfile": { + "bootDiagnostics": { + "enabled": false + } + }, + "provisioningState": 0 + }, + "name": "[parameters('vmName')]", + "type": "Microsoft.Compute/virtualMachines", + "location": "southcentralus" + } + ] + } + } +}` + +var captureTemplate02 = `{ + "operationId": "ac1c7c38-a591-41b3-89bd-ea39fceace1b", + "status": "Succeeded", + "startTime": "2016-04-04T21:07:25.2900874+00:00", + "endTime": "2016-04-04T21:07:26.4776321+00:00" +}` + +func TestCaptureParseJson(t *testing.T) { + var operation CaptureOperation + err := json.Unmarshal([]byte(captureTemplate01), &operation) + if err != nil { + t.Fatalf("failed to the sample capture operation: %s", err) + } + + testSubject := operation.Properties.Output + if testSubject.Schema != "http://schema.management.azure.com/schemas/2014-04-01-preview/VM_IP.json" { + t.Errorf("Schema's value was unexpected: %s", testSubject.Schema) + } + if testSubject.ContentVersion != "1.0.0.0" { + t.Errorf("ContentVersion's value was unexpected: %s", testSubject.ContentVersion) + } + + // == Parameters ==================================== + if len(testSubject.Parameters) != 5 { + t.Fatalf("expected parameters to have 5 keys, but got %d", len(testSubject.Parameters)) + } + if _, ok := testSubject.Parameters["vmName"]; !ok { + t.Errorf("Parameters['vmName'] was an expected parameters, but it did not exist") + } + if testSubject.Parameters["vmName"].Type != "string" { + t.Errorf("Parameters['vmName'].Type == 'string', but got '%s'", testSubject.Parameters["vmName"].Type) + } + if _, ok := testSubject.Parameters["vmSize"]; !ok { + t.Errorf("Parameters['vmSize'] was an expected parameters, but it did not exist") + } + if testSubject.Parameters["vmSize"].Type != "string" { + t.Errorf("Parameters['vmSize'].Type == 'string', but got '%s'", testSubject.Parameters["vmSize"]) + } + if testSubject.Parameters["vmSize"].DefaultValue != "Standard_A2" { + t.Errorf("Parameters['vmSize'].DefaultValue == 'string', but got '%s'", testSubject.Parameters["vmSize"].DefaultValue) + } + + // == Resources ===================================== + if len(testSubject.Resources) != 1 { + t.Fatalf("expected resources to have length 1, but got %d", len(testSubject.Resources)) + } + if testSubject.Resources[0].Name != "[parameters('vmName')]" { + t.Errorf("Resources[0].Name's value was unexpected: %s", testSubject.Resources[0].Name) + } + if testSubject.Resources[0].Type != "Microsoft.Compute/virtualMachines" { + t.Errorf("Resources[0].Type's value was unexpected: %s", testSubject.Resources[0].Type) + } + if testSubject.Resources[0].Location != "southcentralus" { + t.Errorf("Resources[0].Location's value was unexpected: %s", testSubject.Resources[0].Location) + } + + // == Resources/Properties ===================================== + if testSubject.Resources[0].Properties.ProvisioningState != 0 { + t.Errorf("Resources[0].Properties.ProvisioningState's value was unexpected: %d", testSubject.Resources[0].Properties.ProvisioningState) + } + + // == Resources/Properties/HardwareProfile ====================== + hardwareProfile := testSubject.Resources[0].Properties.HardwareProfile + if hardwareProfile.VMSize != "[parameters('vmSize')]" { + t.Errorf("Resources[0].Properties.HardwareProfile.VMSize's value was unexpected: %s", hardwareProfile.VMSize) + } + + // == Resources/Properties/StorageProfile/OSDisk ================ + osDisk := testSubject.Resources[0].Properties.StorageProfile.OSDisk + if osDisk.OSType != "Linux" { + t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.OSDisk's value was unexpected: %s", osDisk.OSType) + } + if osDisk.Name != "packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd" { + t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Name's value was unexpected: %s", osDisk.Name) + } + if osDisk.CreateOption != "FromImage" { + t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.CreateOption's value was unexpected: %s", osDisk.CreateOption) + } + if osDisk.Image.Uri != "http://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd" { + t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Image.Uri's value was unexpected: %s", osDisk.Image.Uri) + } + if osDisk.Vhd.Uri != "http://storage.blob.core.windows.net/vmcontainerce1a1b75-f480-47cb-8e6e-55142e4a5f68/osDisk.ce1a1b75-f480-47cb-8e6e-55142e4a5f68.vhd" { + t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Vhd.Uri's value was unexpected: %s", osDisk.Vhd.Uri) + } + if osDisk.Caching != "ReadWrite" { + t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Caching's value was unexpected: %s", osDisk.Caching) + } + + // == Resources/Properties/OSProfile ============================ + osProfile := testSubject.Resources[0].Properties.OSProfile + if osProfile.AdminPassword != "[parameters('adminPassword')]" { + t.Errorf("Resources[0].Properties.OSProfile.AdminPassword's value was unexpected: %s", osProfile.AdminPassword) + } + if osProfile.AdminUsername != "[parameters('adminUsername')]" { + t.Errorf("Resources[0].Properties.OSProfile.AdminUsername's value was unexpected: %s", osProfile.AdminUsername) + } + if osProfile.ComputerName != "[parameters('vmName')]" { + t.Errorf("Resources[0].Properties.OSProfile.ComputerName's value was unexpected: %s", osProfile.ComputerName) + } + + // == Resources/Properties/NetworkProfile ======================= + networkProfile := testSubject.Resources[0].Properties.NetworkProfile + if len(networkProfile.NetworkInterfaces) != 1 { + t.Errorf("Count of Resources[0].Properties.NetworkProfile.NetworkInterfaces was expected to be 1, but go %d", len(networkProfile.NetworkInterfaces)) + } + if networkProfile.NetworkInterfaces[0].Id != "[parameters('networkInterfaceId')]" { + t.Errorf("Resources[0].Properties.NetworkProfile.NetworkInterfaces[0].Id's value was unexpected: %s", networkProfile.NetworkInterfaces[0].Id) + } + + // == Resources/Properties/DiagnosticsProfile =================== + diagnosticsProfile := testSubject.Resources[0].Properties.DiagnosticsProfile + if diagnosticsProfile.BootDiagnostics.Enabled != false { + t.Errorf("Resources[0].Properties.DiagnosticsProfile.BootDiagnostics.Enabled's value was unexpected: %t", diagnosticsProfile.BootDiagnostics.Enabled) + } +} + +func TestCaptureEmptyOperationJson(t *testing.T) { + var operation CaptureOperation + err := json.Unmarshal([]byte(captureTemplate02), &operation) + if err != nil { + t.Fatalf("failed to the sample capture operation: %s", err) + } + + if operation.Properties != nil { + t.Errorf("JSON contained no properties, but value was not nil: %+v", operation.Properties) + } +} diff --git a/builder/azure/arm/config.go b/builder/azure/arm/config.go index e371c3aa5..ae2fdea4c 100644 --- a/builder/azure/arm/config.go +++ b/builder/azure/arm/config.go @@ -1,29 +1,43 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" "encoding/base64" + "encoding/json" "fmt" "io/ioutil" + "math/big" + "net/http" "time" - "golang.org/x/crypto/ssh" - "github.com/Azure/azure-sdk-for-go/arm/compute" "github.com/Azure/go-autorest/autorest/to" + "github.com/Azure/go-ntlmssp" + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/packer/builder/azure/pkcs12" "github.com/mitchellh/packer/common" "github.com/mitchellh/packer/helper/communicator" "github.com/mitchellh/packer/helper/config" "github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/template/interpolate" + + "github.com/Azure/go-autorest/autorest/azure" + "golang.org/x/crypto/ssh" + "strings" ) const ( - DefaultUserName = "packer" - DefaultVMSize = "Standard_A1" + DefaultCloudEnvironmentName = "Public" + DefaultImageVersion = "latest" + DefaultUserName = "packer" + DefaultVMSize = "Standard_A1" ) type Config struct { @@ -32,6 +46,7 @@ type Config struct { // Authentication via OAUTH ClientID string `mapstructure:"client_id"` ClientSecret string `mapstructure:"client_secret"` + ObjectID string `mapstructure:"object_id"` TenantID string `mapstructure:"tenant_id"` SubscriptionID string `mapstructure:"subscription_id"` @@ -43,46 +58,81 @@ type Config struct { ImagePublisher string `mapstructure:"image_publisher"` ImageOffer string `mapstructure:"image_offer"` ImageSku string `mapstructure:"image_sku"` + ImageVersion string `mapstructure:"image_version"` Location string `mapstructure:"location"` VMSize string `mapstructure:"vm_size"` // Deployment - ResourceGroupName string `mapstructure:"resource_group_name"` - StorageAccount string `mapstructure:"storage_account"` + ResourceGroupName string `mapstructure:"resource_group_name"` + StorageAccount string `mapstructure:"storage_account"` + storageAccountBlobEndpoint string + CloudEnvironmentName string `mapstructure:"cloud_environment_name"` + cloudEnvironment *azure.Environment + + // OS + OSType string `mapstructure:"os_type"` // Runtime Values - UserName string - Password string - tmpAdminPassword string - tmpResourceGroupName string - tmpComputeName string - tmpDeploymentName string - tmpOSDiskName string + UserName string + Password string + tmpAdminPassword string + tmpCertificatePassword string + tmpResourceGroupName string + tmpComputeName string + tmpDeploymentName string + tmpKeyVaultName string + tmpOSDiskName string + tmpWinRMCertificateUrl string + + useDeviceLogin bool // Authentication with the VM via SSH sshAuthorizedKey string sshPrivateKey string + // Authentication with the VM via WinRM + winrmCertificate string + Comm communicator.Config `mapstructure:",squash"` ctx *interpolate.Context } +type keyVaultCertificate struct { + Data string `json:"data"` + DataType string `json:"dataType"` + Password string `json:"password,omitempty"` +} + // If we ever feel the need to support more templates consider moving this // method to its own factory class. func (c *Config) toTemplateParameters() *TemplateParameters { - return &TemplateParameters{ - AdminUsername: &TemplateParameter{c.UserName}, - AdminPassword: &TemplateParameter{c.Password}, - DnsNameForPublicIP: &TemplateParameter{c.tmpComputeName}, - ImageOffer: &TemplateParameter{c.ImageOffer}, - ImagePublisher: &TemplateParameter{c.ImagePublisher}, - ImageSku: &TemplateParameter{c.ImageSku}, - OSDiskName: &TemplateParameter{c.tmpOSDiskName}, - SshAuthorizedKey: &TemplateParameter{c.sshAuthorizedKey}, - StorageAccountName: &TemplateParameter{c.StorageAccount}, - VMSize: &TemplateParameter{c.VMSize}, - VMName: &TemplateParameter{c.tmpComputeName}, + templateParameters := &TemplateParameters{ + AdminUsername: &TemplateParameter{c.UserName}, + AdminPassword: &TemplateParameter{c.Password}, + DnsNameForPublicIP: &TemplateParameter{c.tmpComputeName}, + ImageOffer: &TemplateParameter{c.ImageOffer}, + ImagePublisher: &TemplateParameter{c.ImagePublisher}, + ImageSku: &TemplateParameter{c.ImageSku}, + ImageVersion: &TemplateParameter{c.ImageVersion}, + OSDiskName: &TemplateParameter{c.tmpOSDiskName}, + StorageAccountBlobEndpoint: &TemplateParameter{c.storageAccountBlobEndpoint}, + VMSize: &TemplateParameter{c.VMSize}, + VMName: &TemplateParameter{c.tmpComputeName}, } + + switch c.OSType { + case constants.Target_Linux: + templateParameters.SshAuthorizedKey = &TemplateParameter{c.sshAuthorizedKey} + case constants.Target_Windows: + templateParameters.TenantId = &TemplateParameter{c.TenantID} + templateParameters.ObjectId = &TemplateParameter{c.ObjectID} + + templateParameters.KeyVaultName = &TemplateParameter{c.tmpKeyVaultName} + templateParameters.KeyVaultSecretValue = &TemplateParameter{c.winrmCertificate} + templateParameters.WinRMCertificateUrl = &TemplateParameter{c.tmpWinRMCertificateUrl} + } + + return templateParameters } func (c *Config) toVirtualMachineCaptureParameters() *compute.VirtualMachineCaptureParameters { @@ -93,6 +143,66 @@ func (c *Config) toVirtualMachineCaptureParameters() *compute.VirtualMachineCapt } } +func (c *Config) createCertificate() (string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + err := fmt.Errorf("Failed to Generate Private Key: %s", err) + return "", err + } + + host := fmt.Sprintf("%s.cloudapp.net", c.tmpComputeName) + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + err := fmt.Errorf("Failed to Generate Serial Number: %v", err) + return "", err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Issuer: pkix.Name{ + CommonName: host, + }, + Subject: pkix.Name{ + CommonName: host, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + err = fmt.Errorf("Failed to Create Certificate: %s", err) + return "", err + } + + pfxBytes, err := pkcs12.Encode(derBytes, privateKey, c.tmpCertificatePassword) + if err != nil { + err = fmt.Errorf("Failed to encode certificate as PFX: %s", err) + return "", err + } + + keyVaultDescription := keyVaultCertificate{ + Data: base64.StdEncoding.EncodeToString(pfxBytes), + DataType: "pfx", + Password: c.tmpCertificatePassword, + } + + bytes, err := json.Marshal(keyVaultDescription) + if err != nil { + err = fmt.Errorf("Failed to marshal key vault description: %s", err) + return "", err + } + + return base64.StdEncoding.EncodeToString(bytes), nil +} + func newConfig(raws ...interface{}) (*Config, []string, error) { var c Config @@ -108,12 +218,21 @@ func newConfig(raws ...interface{}) (*Config, []string, error) { provideDefaultValues(&c) setRuntimeValues(&c) setUserNamePassword(&c) + err = setCloudEnvironment(&c) + if err != nil { + return nil, nil, err + } err = setSshValues(&c) if err != nil { return nil, nil, err } + err = setWinRMCertificate(&c) + if err != nil { + return nil, nil, err + } + var errs *packer.MultiError errs = packer.MultiErrorAppend(errs, c.Comm.Prepare(c.ctx)...) @@ -157,18 +276,30 @@ func setSshValues(c *Config) error { c.sshPrivateKey = sshKeyPair.PrivateKey() } + c.Comm.WinRMTransportDecorator = func(t *http.Transport) http.RoundTripper { + return &ntlmssp.Negotiator{RoundTripper: t} + } + return nil } +func setWinRMCertificate(c *Config) error { + cert, err := c.createCertificate() + c.winrmCertificate = cert + + return err +} + func setRuntimeValues(c *Config) { var tempName = NewTempName() c.tmpAdminPassword = tempName.AdminPassword + c.tmpCertificatePassword = tempName.CertificatePassword c.tmpComputeName = tempName.ComputeName c.tmpDeploymentName = tempName.DeploymentName - // c.tmpResourceGroupName = c.ResourceGroupName c.tmpResourceGroupName = tempName.ResourceGroupName c.tmpOSDiskName = tempName.OSDiskName + c.tmpKeyVaultName = tempName.KeyVaultName } func setUserNamePassword(c *Config) { @@ -185,30 +316,77 @@ func setUserNamePassword(c *Config) { } } +func setCloudEnvironment(c *Config) error { + name := strings.ToUpper(c.CloudEnvironmentName) + switch name { + case "CHINA", "CHINACLOUD", "AZURECHINACLOUD": + c.cloudEnvironment = &azure.ChinaCloud + case "PUBLIC", "PUBLICCLOUD", "AZUREPUBLICCLOUD": + c.cloudEnvironment = &azure.PublicCloud + case "USGOVERNMENT", "USGOVERNMENTCLOUD", "AZUREUSGOVERNMENTCLOUD": + c.cloudEnvironment = &azure.USGovernmentCloud + default: + return fmt.Errorf("There is no cloud envionment matching the name '%s'!", c.CloudEnvironmentName) + } + + return nil +} + func provideDefaultValues(c *Config) { if c.VMSize == "" { c.VMSize = DefaultVMSize } + + if c.ImageVersion == "" { + c.ImageVersion = DefaultImageVersion + } + + if c.CloudEnvironmentName == "" { + c.CloudEnvironmentName = DefaultCloudEnvironmentName + } } func assertRequiredParametersSet(c *Config, errs *packer.MultiError) { ///////////////////////////////////////////// // Authentication via OAUTH - if c.ClientID == "" { - errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_id must be specified")) + // Check if device login is being asked for, and is allowed. + // + // Device login is enabled if the user only defines SubscriptionID and not + // ClientID, ClientSecret, and TenantID. + // + // Device login is not enabled for Windows because the WinRM certificate is + // readable by the ObjectID of the App. There may be another way to handle + // this case, but I am not currently aware of it - send feedback. + isUseDeviceLogin := func(c *Config) bool { + if c.OSType == constants.Target_Windows { + return false + } + + return c.SubscriptionID != "" && + c.ClientID == "" && + c.ClientSecret == "" && + c.TenantID == "" } - if c.ClientSecret == "" { - errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_secret must be specified")) - } + if isUseDeviceLogin(c) { + c.useDeviceLogin = true + } else { + if c.ClientID == "" { + errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_id must be specified")) + } - if c.TenantID == "" { - errs = packer.MultiErrorAppend(errs, fmt.Errorf("A tenant_id must be specified")) - } + if c.ClientSecret == "" { + errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_secret must be specified")) + } - if c.SubscriptionID == "" { - errs = packer.MultiErrorAppend(errs, fmt.Errorf("A subscription_id must be specified")) + if c.TenantID == "" { + errs = packer.MultiErrorAppend(errs, fmt.Errorf("A tenant_id must be specified")) + } + + if c.SubscriptionID == "" { + errs = packer.MultiErrorAppend(errs, fmt.Errorf("A subscription_id must be specified")) + } } ///////////////////////////////////////////// @@ -223,7 +401,6 @@ func assertRequiredParametersSet(c *Config, errs *packer.MultiError) { ///////////////////////////////////////////// // Compute - if c.ImagePublisher == "" { errs = packer.MultiErrorAppend(errs, fmt.Errorf("A image_publisher must be specified")) } @@ -242,8 +419,13 @@ func assertRequiredParametersSet(c *Config, errs *packer.MultiError) { ///////////////////////////////////////////// // Deployment - if c.StorageAccount == "" { errs = packer.MultiErrorAppend(errs, fmt.Errorf("A storage_account must be specified")) } + + ///////////////////////////////////////////// + // OS + if c.OSType != constants.Target_Linux && c.OSType != constants.Target_Windows { + errs = packer.MultiErrorAppend(errs, fmt.Errorf("An os_type must be specified")) + } } diff --git a/builder/azure/arm/config_test.go b/builder/azure/arm/config_test.go index dc183b8fa..e686bd60b 100644 --- a/builder/azure/arm/config_test.go +++ b/builder/azure/arm/config_test.go @@ -1,14 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( - "bytes" "encoding/json" "fmt" + "strings" "testing" "time" + + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/packer/packer" ) // List of configuration parameters that are required by the ARM builder. @@ -21,6 +24,7 @@ var requiredConfigValues = []string{ "image_publisher", "image_sku", "location", + "os_type", "storage_account", "subscription_id", "tenant_id", @@ -41,21 +45,23 @@ func TestConfigShouldProvideReasonableDefaultValues(t *testing.T) { if c.VMSize == "" { t.Errorf("Expected 'VMSize' to be populated, but it was empty!") } + + if c.ObjectID != "" { + t.Errorf("Expected 'ObjectID' to be nil, but it was '%s'!", c.ObjectID) + } } func TestConfigShouldBeAbleToOverrideDefaultedValues(t *testing.T) { - builderValues := make(map[string]string) - - // Populate the dictionary with all of the required values. - for _, v := range requiredConfigValues { - builderValues[v] = "--some-value--" - } - + builderValues := getArmBuilderConfiguration() builderValues["ssh_password"] = "override_password" builderValues["ssh_username"] = "override_username" builderValues["vm_size"] = "override_vm_size" - c, _, _ := newConfig(getArmBuilderConfigurationFromMap(builderValues), getPackerConfiguration()) + c, _, err := newConfig(builderValues, getPackerConfiguration()) + + if err != nil { + t.Fatalf("newConfig failed: %s", err) + } if c.Password != "override_password" { t.Errorf("Expected 'Password' to be set to 'override_password', but found '%s'!", c.Password) @@ -74,7 +80,7 @@ func TestConfigShouldBeAbleToOverrideDefaultedValues(t *testing.T) { } if c.VMSize != "override_vm_size" { - t.Errorf("Expected 'vm_size' to be set to 'override_username', but found '%s'!", c.VMSize) + t.Errorf("Expected 'vm_size' to be set to 'override_vm_size', but found '%s'!", c.VMSize) } } @@ -86,16 +92,77 @@ func TestConfigShouldDefaultVMSizeToStandardA1(t *testing.T) { } } -func TestUserShouldProvideRequiredValues(t *testing.T) { - builderValues := make(map[string]string) +func TestConfigShouldDefaultImageVersionToLatest(t *testing.T) { + c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration()) - // Populate the dictionary with all of the required values. - for _, v := range requiredConfigValues { - builderValues[v] = "--some-value--" + if c.ImageVersion != "latest" { + t.Errorf("Expected 'ImageVersion' to default to 'latest', but got '%s'.", c.ImageVersion) + } +} + +func TestConfigShouldDefaultToPublicCloud(t *testing.T) { + c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration()) + + if c.CloudEnvironmentName != "Public" { + t.Errorf("Expected 'CloudEnvironmentName' to default to 'Public', but got '%s'.", c.CloudEnvironmentName) } + if c.cloudEnvironment == nil || c.cloudEnvironment.Name != "AzurePublicCloud" { + t.Errorf("Expected 'cloudEnvironment' to be set to 'AzurePublicCloud', but got '%s'.", c.cloudEnvironment) + } +} + +func TestConfigInstantiatesCorrectAzureEnvironment(t *testing.T) { + config := map[string]string{ + "capture_name_prefix": "ignore", + "capture_container_name": "ignore", + "image_offer": "ignore", + "image_publisher": "ignore", + "image_sku": "ignore", + "location": "ignore", + "storage_account": "ignore", + "subscription_id": "ignore", + "os_type": constants.Target_Linux, + } + + // user input is fun :) + var table = []struct { + name string + environmentName string + }{ + {"China", "AzureChinaCloud"}, + {"ChinaCloud", "AzureChinaCloud"}, + {"AzureChinaCloud", "AzureChinaCloud"}, + {"aZuReChInAcLoUd", "AzureChinaCloud"}, + + {"USGovernment", "AzureUSGovernmentCloud"}, + {"USGovernmentCloud", "AzureUSGovernmentCloud"}, + {"AzureUSGovernmentCloud", "AzureUSGovernmentCloud"}, + {"aZuReUsGoVeRnMeNtClOuD", "AzureUSGovernmentCloud"}, + + {"Public", "AzurePublicCloud"}, + {"PublicCloud", "AzurePublicCloud"}, + {"AzurePublicCloud", "AzurePublicCloud"}, + {"aZuRePuBlIcClOuD", "AzurePublicCloud"}, + } + + packerConfiguration := getPackerConfiguration() + + for _, x := range table { + config["cloud_environment_name"] = x.name + c, _, _ := newConfig(config, packerConfiguration) + + if c.cloudEnvironment == nil || c.cloudEnvironment.Name != x.environmentName { + t.Errorf("Expected 'cloudEnvironment' to be set to '%s', but got '%s'.", x.environmentName, c.cloudEnvironment) + } + } +} + +func TestUserShouldProvideRequiredValues(t *testing.T) { + builderValues := getArmBuilderConfiguration() + // Ensure we can successfully create a config. - _, _, err := newConfig(getArmBuilderConfigurationFromMap(builderValues), getPackerConfiguration()) + _, _, err := newConfig(builderValues, getPackerConfiguration()) if err != nil { t.Errorf("Expected configuration creation to succeed, but it failed!\n") t.Fatalf(" -> %+v\n", builderValues) @@ -103,15 +170,16 @@ func TestUserShouldProvideRequiredValues(t *testing.T) { // Take away a required element, and ensure construction fails. for _, v := range requiredConfigValues { + originalValue := builderValues[v] delete(builderValues, v) - _, _, err := newConfig(getArmBuilderConfigurationFromMap(builderValues), getPackerConfiguration()) + _, _, err := newConfig(builderValues, getPackerConfiguration()) if err == nil { t.Errorf("Expected configuration creation to fail, but it succeeded!\n") t.Fatalf(" -> %+v\n", builderValues) } - builderValues[v] = "--some-value--" + builderValues[v] = originalValue } } @@ -167,8 +235,8 @@ func TestConfigShouldTransformToTemplateParameters(t *testing.T) { t.Errorf("Expected OSDiskName to be equal to config's OSDiskName, but they were '%s' and '%s' respectively.", templateParameters.OSDiskName.Value, c.tmpOSDiskName) } - if templateParameters.StorageAccountName.Value != c.StorageAccount { - t.Errorf("Expected StorageAccountName to be equal to config's StorageAccountName, but they were '%s' and '%s' respectively.", templateParameters.StorageAccountName.Value, c.StorageAccount) + if templateParameters.StorageAccountBlobEndpoint.Value != c.storageAccountBlobEndpoint { + t.Errorf("Expected StorageAccountBlobEndpoint to be equal to config's storageAccountBlobEndpoint, but they were '%s' and '%s' respectively.", templateParameters.StorageAccountBlobEndpoint.Value, c.storageAccountBlobEndpoint) } if templateParameters.VMName.Value != c.tmpComputeName { @@ -180,6 +248,50 @@ func TestConfigShouldTransformToTemplateParameters(t *testing.T) { } } +func TestConfigShouldTransformToTemplateParametersLinux(t *testing.T) { + c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration()) + c.OSType = constants.Target_Linux + templateParameters := c.toTemplateParameters() + + if templateParameters.KeyVaultSecretValue != nil { + t.Errorf("Expected KeyVaultSecretValue to be empty for an os_type == '%s', but it was not.", c.OSType) + } + + if templateParameters.ObjectId != nil { + t.Errorf("Expected ObjectId to be empty for an os_type == '%s', but it was not.", c.OSType) + } + + if templateParameters.TenantId != nil { + t.Errorf("Expected TenantId to be empty for an os_type == '%s', but it was not.", c.OSType) + } +} + +func TestConfigShouldTransformToTemplateParametersWindows(t *testing.T) { + c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration()) + c.OSType = constants.Target_Windows + templateParameters := c.toTemplateParameters() + + if templateParameters.SshAuthorizedKey != nil { + t.Errorf("Expected SshAuthorizedKey to be empty for an os_type == '%s', but it was not.", c.OSType) + } + + if templateParameters.KeyVaultName == nil { + t.Errorf("Expected KeyVaultName to not be empty for an os_type == '%s', but it was not.", c.OSType) + } + + if templateParameters.KeyVaultSecretValue == nil { + t.Errorf("Expected KeyVaultSecretValue to not be empty for an os_type == '%s', but it was not.", c.OSType) + } + + if templateParameters.ObjectId == nil { + t.Errorf("Expected ObjectId to not be empty for an os_type == '%s', but it was not.", c.OSType) + } + + if templateParameters.TenantId == nil { + t.Errorf("Expected TenantId to not be empty for an os_type == '%s', but it was not.", c.OSType) + } +} + func TestConfigShouldTransformToVirtualMachineCaptureParameters(t *testing.T) { c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration()) parameters := c.toVirtualMachineCaptureParameters() @@ -216,30 +328,67 @@ func TestConfigShouldSupportPackersConfigElements(t *testing.T) { } } -func getArmBuilderConfiguration() interface{} { +func TestUserDeviceLoginIsEnabledForLinux(t *testing.T) { + config := map[string]string{ + "capture_name_prefix": "ignore", + "capture_container_name": "ignore", + "image_offer": "ignore", + "image_publisher": "ignore", + "image_sku": "ignore", + "location": "ignore", + "storage_account": "ignore", + "subscription_id": "ignore", + "os_type": constants.Target_Linux, + } + + _, _, err := newConfig(config, getPackerConfiguration()) + if err != nil { + t.Fatalf("failed to use device login for Linux: %s", err) + } +} + +func TestUseDeviceLoginIsDisabledForWindows(t *testing.T) { + config := map[string]string{ + "capture_name_prefix": "ignore", + "capture_container_name": "ignore", + "image_offer": "ignore", + "image_publisher": "ignore", + "image_sku": "ignore", + "location": "ignore", + "storage_account": "ignore", + "subscription_id": "ignore", + "os_type": constants.Target_Windows, + } + + _, _, err := newConfig(config, getPackerConfiguration()) + if err == nil { + t.Fatalf("Expected test to fail, but it succeeded") + } + + multiError, _ := err.(*packer.MultiError) + if len(multiError.Errors) != 3 { + t.Errorf("Expected to find 3 errors, but found %d errors", len(multiError.Errors)) + } + + if !strings.Contains(err.Error(), "client_id must be specified") { + t.Errorf("Expected to find error for 'client_id must be specified") + } + if !strings.Contains(err.Error(), "client_secret must be specified") { + t.Errorf("Expected to find error for 'client_secret must be specified") + } + if !strings.Contains(err.Error(), "tenant_id must be specified") { + t.Errorf("Expected to find error for 'tenant_id must be specified") + } +} + +func getArmBuilderConfiguration() map[string]string { m := make(map[string]string) for _, v := range requiredConfigValues { m[v] = fmt.Sprintf("%s00", v) } - return getArmBuilderConfigurationFromMap(m) -} - -func getArmBuilderConfigurationFromMap(kvp map[string]string) interface{} { - bs := bytes.NewBufferString("{") - - for k, v := range kvp { - bs.WriteString(fmt.Sprintf("\"%s\": \"%s\",\n", k, v)) - } - - // remove the trailing ",\n" because JSON - bs.Truncate(bs.Len() - 2) - bs.WriteString("}") - - var config interface{} - json.Unmarshal([]byte(bs.String()), &config) - - return config + m["os_type"] = constants.Target_Linux + return m } func getPackerConfiguration() interface{} { @@ -260,14 +409,11 @@ func getPackerConfiguration() interface{} { return config } -func getPackerCommunicatorConfiguration() interface{} { - var doc = `{ - "ssh_timeout": "1h", - "winrm_timeout": "2h" - }` - - var config interface{} - json.Unmarshal([]byte(doc), &config) +func getPackerCommunicatorConfiguration() map[string]string { + config := map[string]string{ + "ssh_timeout": "1h", + "winrm_timeout": "2h", + } return config } diff --git a/builder/azure/arm/deployment_factory.go b/builder/azure/arm/deployment_factory.go index 80bcd914f..f563d8f4f 100644 --- a/builder/azure/arm/deployment_factory.go +++ b/builder/azure/arm/deployment_factory.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm diff --git a/builder/azure/arm/deployment_factory_test.go b/builder/azure/arm/deployment_factory_test.go index 9bc4a60dc..f3fe69996 100644 --- a/builder/azure/arm/deployment_factory_test.go +++ b/builder/azure/arm/deployment_factory_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -79,13 +79,13 @@ func TestMalformedTemplatesShouldReturnError(t *testing.T) { func getTemplateParameters() TemplateParameters { templateParameters := TemplateParameters{ - AdminUsername: &TemplateParameter{"adminusername00"}, - DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"}, - OSDiskName: &TemplateParameter{"osdiskname00"}, - SshAuthorizedKey: &TemplateParameter{"sshkeydata00"}, - StorageAccountName: &TemplateParameter{"storageaccountname00"}, - VMName: &TemplateParameter{"vmname00"}, - VMSize: &TemplateParameter{"vmsize00"}, + AdminUsername: &TemplateParameter{"adminusername00"}, + DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"}, + OSDiskName: &TemplateParameter{"osdiskname00"}, + SshAuthorizedKey: &TemplateParameter{"sshkeydata00"}, + StorageAccountBlobEndpoint: &TemplateParameter{"storageaccountblobendpoint00"}, + VMName: &TemplateParameter{"vmname00"}, + VMSize: &TemplateParameter{"vmsize00"}, } return templateParameters diff --git a/builder/azure/arm/deployment_poller.go b/builder/azure/arm/deployment_poller.go index b98fd8404..1ce76aa2b 100644 --- a/builder/azure/arm/deployment_poller.go +++ b/builder/azure/arm/deployment_poller.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm diff --git a/builder/azure/arm/deployment_poller_test.go b/builder/azure/arm/deployment_poller_test.go index 7ec25c0dd..83c90bc49 100644 --- a/builder/azure/arm/deployment_poller_test.go +++ b/builder/azure/arm/deployment_poller_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm diff --git a/builder/azure/arm/inspector.go b/builder/azure/arm/inspector.go new file mode 100644 index 000000000..2ad5ab807 --- /dev/null +++ b/builder/azure/arm/inspector.go @@ -0,0 +1,70 @@ +package arm + +import ( + "bytes" + "io/ioutil" + "log" + "net/http" + + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/mitchellh/packer/builder/azure/common/logutil" + "io" +) + +func chop(data []byte, maxlen int64) string { + s := string(data) + if int64(len(s)) > maxlen { + s = s[:maxlen] + "..." + } + return s +} + +func handleBody(body io.ReadCloser, maxlen int64) (io.ReadCloser, string) { + if body == nil { + return nil, "" + } + + defer body.Close() + + b, err := ioutil.ReadAll(body) + if err != nil { + return nil, "" + } + + return ioutil.NopCloser(bytes.NewReader(b)), chop(b, maxlen) +} + +func withInspection(maxlen int64) autorest.PrepareDecorator { + return func(p autorest.Preparer) autorest.Preparer { + return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) { + body, bodyString := handleBody(r.Body, maxlen) + r.Body = body + + log.Print("Azure request", logutil.Fields{ + "method": r.Method, + "request": r.URL.String(), + "body": bodyString, + }) + return p.Prepare(r) + }) + } +} + +func byInspecting(maxlen int64) autorest.RespondDecorator { + return func(r autorest.Responder) autorest.Responder { + return autorest.ResponderFunc(func(resp *http.Response) error { + body, bodyString := handleBody(resp.Body, maxlen) + resp.Body = body + + log.Print("Azure response", logutil.Fields{ + "status": resp.Status, + "method": resp.Request.Method, + "request": resp.Request.URL.String(), + "x-ms-request-id": azure.ExtractRequestID(resp), + "body": bodyString, + }) + return r.Respond(resp) + }) + } +} diff --git a/builder/azure/arm/openssh_key_pair.go b/builder/azure/arm/openssh_key_pair.go index f8529b8d8..a775525ae 100644 --- a/builder/azure/arm/openssh_key_pair.go +++ b/builder/azure/arm/openssh_key_pair.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -10,9 +10,8 @@ import ( "encoding/base64" "encoding/pem" "fmt" - "time" - "golang.org/x/crypto/ssh" + "time" ) const ( diff --git a/builder/azure/arm/openssh_key_pair_test.go b/builder/azure/arm/openssh_key_pair_test.go index eb4886ba2..2f538070c 100644 --- a/builder/azure/arm/openssh_key_pair_test.go +++ b/builder/azure/arm/openssh_key_pair_test.go @@ -1,14 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( - "testing" - "golang.org/x/crypto/ssh" + "testing" ) +func TestFart(t *testing.T) { + +} + func TestAuthorizedKeyShouldParse(t *testing.T) { testSubject, err := NewOpenSshKeyPairWithSize(512) if err != nil { diff --git a/builder/azure/arm/step.go b/builder/azure/arm/step.go new file mode 100644 index 000000000..3b4893dbb --- /dev/null +++ b/builder/azure/arm/step.go @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "github.com/mitchellh/packer/builder/azure/common" + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" +) + +func processInterruptibleResult( + result common.InterruptibleTaskResult, sayError func(error), state multistep.StateBag) multistep.StepAction { + if result.IsCancelled { + return multistep.ActionHalt + } + + return processStepResult(result.Err, sayError, state) +} + +func processStepResult( + err error, sayError func(error), state multistep.StateBag) multistep.StepAction { + + if err != nil { + state.Put(constants.Error, err) + sayError(err) + + return multistep.ActionHalt + } + + return multistep.ActionContinue + +} diff --git a/builder/azure/arm/step_capture_image.go b/builder/azure/arm/step_capture_image.go index 92b80caa0..aa21ef670 100644 --- a/builder/azure/arm/step_capture_image.go +++ b/builder/azure/arm/step_capture_image.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,14 +7,16 @@ import ( "fmt" "github.com/Azure/azure-sdk-for-go/arm/compute" - "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/builder/azure/common" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) type StepCaptureImage struct { client *AzureClient - capture func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters) error + capture func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters, cancelCh <-chan struct{}) error + get func(client *AzureClient) *CaptureTemplate say func(message string) error func(e error) } @@ -22,6 +24,7 @@ type StepCaptureImage struct { func NewStepCaptureImage(client *AzureClient, ui packer.Ui) *StepCaptureImage { var step = &StepCaptureImage{ client: client, + get: func(client *AzureClient) *CaptureTemplate { return client.Template }, say: func(message string) { ui.Say(message) }, error: func(e error) { ui.Error(e.Error()) }, } @@ -30,20 +33,17 @@ func NewStepCaptureImage(client *AzureClient, ui packer.Ui) *StepCaptureImage { return step } -func (s *StepCaptureImage) captureImage(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters) error { - generalizeResponse, err := s.client.Generalize(resourceGroupName, computeName) +func (s *StepCaptureImage) captureImage(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters, cancelCh <-chan struct{}) error { + _, err := s.client.Generalize(resourceGroupName, computeName) if err != nil { return err } - s.client.VirtualMachinesClient.PollAsNeeded(generalizeResponse.Response) - - captureResponse, err := s.client.Capture(resourceGroupName, computeName, *parameters) + _, err = s.client.Capture(resourceGroupName, computeName, *parameters, cancelCh) if err != nil { return err } - s.client.VirtualMachinesClient.PollAsNeeded(captureResponse.Response.Response) return nil } @@ -57,15 +57,24 @@ func (s *StepCaptureImage) Run(state multistep.StateBag) multistep.StepAction { s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName)) s.say(fmt.Sprintf(" -> ComputeName : '%s'", computeName)) - err := s.capture(resourceGroupName, computeName, parameters) - if err != nil { - state.Put(constants.Error, err) - s.error(err) + result := common.StartInterruptibleTask( + func() bool { return common.IsStateCancelled(state) }, + func(cancelCh <-chan struct{}) error { + return s.capture(resourceGroupName, computeName, parameters, cancelCh) + }) - return multistep.ActionHalt - } + // HACK(chrboum): I do not like this. The capture method should be returning this value + // instead having to pass in another lambda. I'm in this pickle because I am using + // common.StartInterruptibleTask which is not parametric, and only returns a type of error. + // I could change it to interface{}, but I do not like that solution either. + // + // Having to resort to capturing the template via an inspector is hack, and once I can + // resolve that I can cleanup this code too. See the comments in azure_client.go for more + // details. + template := s.get(s.client) + state.Put(constants.ArmCaptureTemplate, template) - return multistep.ActionContinue + return processInterruptibleResult(result, s.error, state) } func (*StepCaptureImage) Cleanup(multistep.StateBag) { diff --git a/builder/azure/arm/step_capture_image_test.go b/builder/azure/arm/step_capture_image_test.go index 653f48722..52c69b5da 100644 --- a/builder/azure/arm/step_capture_image_test.go +++ b/builder/azure/arm/step_capture_image_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -8,16 +8,18 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/arm/compute" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepCaptureImageShouldFailIfCaptureFails(t *testing.T) { - var testSubject = &StepCaptureImage{ - capture: func(string, string, *compute.VirtualMachineCaptureParameters) error { + capture: func(string, string, *compute.VirtualMachineCaptureParameters, <-chan struct{}) error { return fmt.Errorf("!! Unit Test FAIL !!") }, + get: func(client *AzureClient) *CaptureTemplate { + return nil + }, say: func(message string) {}, error: func(e error) {}, } @@ -36,9 +38,12 @@ func TestStepCaptureImageShouldFailIfCaptureFails(t *testing.T) { func TestStepCaptureImageShouldPassIfCapturePasses(t *testing.T) { var testSubject = &StepCaptureImage{ - capture: func(string, string, *compute.VirtualMachineCaptureParameters) error { return nil }, - say: func(message string) {}, - error: func(e error) {}, + capture: func(string, string, *compute.VirtualMachineCaptureParameters, <-chan struct{}) error { return nil }, + get: func(client *AzureClient) *CaptureTemplate { + return nil + }, + say: func(message string) {}, + error: func(e error) {}, } stateBag := createTestStateBagStepCaptureImage() @@ -54,18 +59,27 @@ func TestStepCaptureImageShouldPassIfCapturePasses(t *testing.T) { } func TestStepCaptureImageShouldTakeStepArgumentsFromStateBag(t *testing.T) { + cancelCh := make(chan<- struct{}) + defer close(cancelCh) + var actualResourceGroupName string var actualComputeName string var actualVirtualMachineCaptureParameters *compute.VirtualMachineCaptureParameters + actualCaptureTemplate := &CaptureTemplate{ + Schema: "!! Unit Test !!", + } var testSubject = &StepCaptureImage{ - capture: func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters) error { + capture: func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters, cancelCh <-chan struct{}) error { actualResourceGroupName = resourceGroupName actualComputeName = computeName actualVirtualMachineCaptureParameters = parameters return nil }, + get: func(client *AzureClient) *CaptureTemplate { + return actualCaptureTemplate + }, say: func(message string) {}, error: func(e error) {}, } @@ -80,6 +94,7 @@ func TestStepCaptureImageShouldTakeStepArgumentsFromStateBag(t *testing.T) { var expectedComputeName = stateBag.Get(constants.ArmComputeName).(string) var expectedResourceGroupName = stateBag.Get(constants.ArmResourceGroupName).(string) var expectedVirtualMachineCaptureParameters = stateBag.Get(constants.ArmVirtualMachineCaptureParameters).(*compute.VirtualMachineCaptureParameters) + var expectedCaptureTemplate = stateBag.Get(constants.ArmCaptureTemplate).(*CaptureTemplate) if actualComputeName != expectedComputeName { t.Fatalf("Expected StepCaptureImage to source 'constants.ArmComputeName' from the state bag, but it did not.") @@ -92,6 +107,10 @@ func TestStepCaptureImageShouldTakeStepArgumentsFromStateBag(t *testing.T) { if actualVirtualMachineCaptureParameters != expectedVirtualMachineCaptureParameters { t.Fatalf("Expected StepCaptureImage to source 'constants.ArmVirtualMachineCaptureParameters' from the state bag, but it did not.") } + + if actualCaptureTemplate != expectedCaptureTemplate { + t.Fatalf("Expected StepCaptureImage to source 'constants.ArmCaptureTemplate' from the state bag, but it did not.") + } } func createTestStateBagStepCaptureImage() multistep.StateBag { diff --git a/builder/azure/arm/step_create_resource_group.go b/builder/azure/arm/step_create_resource_group.go index 3a51bfd0e..85569c4a8 100644 --- a/builder/azure/arm/step_create_resource_group.go +++ b/builder/azure/arm/step_create_resource_group.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,8 +7,8 @@ import ( "fmt" "github.com/Azure/azure-sdk-for-go/arm/resources/resources" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) @@ -48,15 +48,29 @@ func (s *StepCreateResourceGroup) Run(state multistep.StateBag) multistep.StepAc s.say(fmt.Sprintf(" -> Location : '%s'", location)) err := s.create(resourceGroupName, location) - if err != nil { - state.Put(constants.Error, err) - s.error(err) - - return multistep.ActionHalt + if err == nil { + state.Put(constants.ArmIsResourceGroupCreated, true) } - return multistep.ActionContinue + return processStepResult(err, s.error, state) } -func (*StepCreateResourceGroup) Cleanup(multistep.StateBag) { +func (s *StepCreateResourceGroup) Cleanup(state multistep.StateBag) { + _, ok := state.GetOk(constants.ArmIsResourceGroupCreated) + if !ok { + return + } + + ui := state.Get("ui").(packer.Ui) + ui.Say("\nCleanup requested, deleting resource group ...") + + var resourceGroupName = state.Get(constants.ArmResourceGroupName).(string) + _, err := s.client.GroupsClient.Delete(resourceGroupName, nil) + if err != nil { + ui.Error(fmt.Sprintf("Error deleting resource group. Please delete it manually.\n\n"+ + "Name: %s\n"+ + "Error: %s", resourceGroupName, err)) + } + + ui.Say("Resource group has been deleted.") } diff --git a/builder/azure/arm/step_create_resource_group_test.go b/builder/azure/arm/step_create_resource_group_test.go index 262979e56..fe42a3129 100644 --- a/builder/azure/arm/step_create_resource_group_test.go +++ b/builder/azure/arm/step_create_resource_group_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,8 +7,8 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepCreateResourceGroupShouldFailIfCreateFails(t *testing.T) { @@ -80,6 +80,11 @@ func TestStepCreateResourceGroupShouldTakeStepArgumentsFromStateBag(t *testing.T if actualLocation != expectedLocation { t.Fatalf("Expected the step to source 'constants.ArmResourceGroupName' from the state bag, but it did not.") } + + _, ok := stateBag.GetOk(constants.ArmIsResourceGroupCreated) + if !ok { + t.Fatalf("Expected the step to add item to stateBag['constants.ArmIsResourceGroupCreated'], but it did not.") + } } func createTestStateBagStepCreateResourceGroup() multistep.StateBag { diff --git a/builder/azure/arm/step_delete_os_disk.go b/builder/azure/arm/step_delete_os_disk.go index df3138c1a..1c996f9b4 100644 --- a/builder/azure/arm/step_delete_os_disk.go +++ b/builder/azure/arm/step_delete_os_disk.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -33,7 +33,7 @@ func NewStepDeleteOSDisk(client *AzureClient, ui packer.Ui) *StepDeleteOSDisk { } func (s *StepDeleteOSDisk) deleteBlob(storageContainerName string, blobName string) error { - return s.client.BlobStorageClient.DeleteBlob(storageContainerName, blobName) + return s.client.BlobStorageClient.DeleteBlob(storageContainerName, blobName, nil) } func (s *StepDeleteOSDisk) Run(state multistep.StateBag) multistep.StepAction { @@ -54,13 +54,7 @@ func (s *StepDeleteOSDisk) Run(state multistep.StateBag) multistep.StepAction { var blobName = strings.Join(xs[2:], "/") err = s.delete(storageAccountName, blobName) - if err != nil { - state.Put(constants.Error, err) - s.error(err) - - return multistep.ActionHalt - } - return multistep.ActionContinue + return processStepResult(err, s.error, state) } func (*StepDeleteOSDisk) Cleanup(multistep.StateBag) { diff --git a/builder/azure/arm/step_delete_os_disk_test.go b/builder/azure/arm/step_delete_os_disk_test.go index 2355bf283..214cfec74 100644 --- a/builder/azure/arm/step_delete_os_disk_test.go +++ b/builder/azure/arm/step_delete_os_disk_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,8 +7,8 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepDeleteOSDiskShouldFailIfGetFails(t *testing.T) { diff --git a/builder/azure/arm/step_delete_resource_group.go b/builder/azure/arm/step_delete_resource_group.go index 724b945ea..62e67919e 100644 --- a/builder/azure/arm/step_delete_resource_group.go +++ b/builder/azure/arm/step_delete_resource_group.go @@ -1,19 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( "fmt" - "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/builder/azure/common" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) type StepDeleteResourceGroup struct { client *AzureClient - delete func(resourceGroupName string) error + delete func(resourceGroupName string, cancelCh <-chan struct{}) error say func(message string) error func(e error) } @@ -29,32 +30,23 @@ func NewStepDeleteResourceGroup(client *AzureClient, ui packer.Ui) *StepDeleteRe return step } -func (s *StepDeleteResourceGroup) deleteResourceGroup(resourceGroupName string) error { - res, err := s.client.GroupsClient.Delete(resourceGroupName) - if err != nil { - return err - } +func (s *StepDeleteResourceGroup) deleteResourceGroup(resourceGroupName string, cancelCh <-chan struct{}) error { + _, err := s.client.GroupsClient.Delete(resourceGroupName, cancelCh) - s.client.GroupsClient.PollAsNeeded(res.Response) - return nil + return err } func (s *StepDeleteResourceGroup) Run(state multistep.StateBag) multistep.StepAction { s.say("Deleting resource group ...") var resourceGroupName = state.Get(constants.ArmResourceGroupName).(string) - s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName)) - err := s.delete(resourceGroupName) - if err != nil { - state.Put(constants.Error, err) - s.error(err) + result := common.StartInterruptibleTask( + func() bool { return common.IsStateCancelled(state) }, + func(cancelCh <-chan struct{}) error { return s.delete(resourceGroupName, cancelCh) }) - return multistep.ActionHalt - } - - return multistep.ActionContinue + return processInterruptibleResult(result, s.error, state) } func (*StepDeleteResourceGroup) Cleanup(multistep.StateBag) { diff --git a/builder/azure/arm/step_delete_resource_group_test.go b/builder/azure/arm/step_delete_resource_group_test.go index e99e89600..3ba2dbf17 100644 --- a/builder/azure/arm/step_delete_resource_group_test.go +++ b/builder/azure/arm/step_delete_resource_group_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,13 +7,13 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepDeleteResourceGroupShouldFailIfDeleteFails(t *testing.T) { var testSubject = &StepDeleteResourceGroup{ - delete: func(string) error { return fmt.Errorf("!! Unit Test FAIL !!") }, + delete: func(string, <-chan struct{}) error { return fmt.Errorf("!! Unit Test FAIL !!") }, say: func(message string) {}, error: func(e error) {}, } @@ -32,7 +32,7 @@ func TestStepDeleteResourceGroupShouldFailIfDeleteFails(t *testing.T) { func TestStepDeleteResourceGroupShouldPassIfDeletePasses(t *testing.T) { var testSubject = &StepDeleteResourceGroup{ - delete: func(string) error { return nil }, + delete: func(string, <-chan struct{}) error { return nil }, say: func(message string) {}, error: func(e error) {}, } @@ -53,7 +53,7 @@ func TestStepDeleteResourceGroupShouldTakeStepArgumentsFromStateBag(t *testing.T var actualResourceGroupName string var testSubject = &StepDeleteResourceGroup{ - delete: func(resourceGroupName string) error { + delete: func(resourceGroupName string, cancelCh <-chan struct{}) error { actualResourceGroupName = resourceGroupName return nil }, diff --git a/builder/azure/arm/step_deploy_template.go b/builder/azure/arm/step_deploy_template.go index 7ad9b8578..7ec7e484e 100644 --- a/builder/azure/arm/step_deploy_template.go +++ b/builder/azure/arm/step_deploy_template.go @@ -1,47 +1,49 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( "fmt" - "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/builder/azure/common" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) type StepDeployTemplate struct { - client *AzureClient - deploy func(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error - say func(message string) - error func(e error) + client *AzureClient + template string + deploy func(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters, cancelCh <-chan struct{}) error + say func(message string) + error func(e error) } -func NewStepDeployTemplate(client *AzureClient, ui packer.Ui) *StepDeployTemplate { +func NewStepDeployTemplate(client *AzureClient, ui packer.Ui, template string) *StepDeployTemplate { var step = &StepDeployTemplate{ - client: client, - say: func(message string) { ui.Say(message) }, - error: func(e error) { ui.Error(e.Error()) }, + client: client, + template: template, + say: func(message string) { ui.Say(message) }, + error: func(e error) { ui.Error(e.Error()) }, } step.deploy = step.deployTemplate return step } -func (s *StepDeployTemplate) deployTemplate(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error { - factory := newDeploymentFactory(Linux) +func (s *StepDeployTemplate) deployTemplate(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters, cancelCh <-chan struct{}) error { + factory := newDeploymentFactory(s.template) deployment, err := factory.create(*templateParameters) if err != nil { return err } - res, err := s.client.DeploymentsClient.CreateOrUpdate(resourceGroupName, deploymentName, *deployment) + _, err = s.client.DeploymentsClient.CreateOrUpdate(resourceGroupName, deploymentName, *deployment, cancelCh) if err != nil { return err } - s.client.DeploymentsClient.PollAsNeeded(res.Response.Response) poller := NewDeploymentPoller(func() (string, error) { r, e := s.client.DeploymentsClient.Get(resourceGroupName, deploymentName) if r.Properties != nil && r.Properties.ProvisioningState != nil { @@ -73,15 +75,14 @@ func (s *StepDeployTemplate) Run(state multistep.StateBag) multistep.StepAction s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName)) s.say(fmt.Sprintf(" -> DeploymentName : '%s'", deploymentName)) - err := s.deploy(resourceGroupName, deploymentName, templateParameters) - if err != nil { - state.Put(constants.Error, err) - s.error(err) + result := common.StartInterruptibleTask( + func() bool { return common.IsStateCancelled(state) }, + func(cancelCh <-chan struct{}) error { + return s.deploy(resourceGroupName, deploymentName, templateParameters, cancelCh) + }, + ) - return multistep.ActionHalt - } - - return multistep.ActionContinue + return processInterruptibleResult(result, s.error, state) } func (*StepDeployTemplate) Cleanup(multistep.StateBag) { diff --git a/builder/azure/arm/step_deploy_template_test.go b/builder/azure/arm/step_deploy_template_test.go index 6f61ac2e2..330fdb8c8 100644 --- a/builder/azure/arm/step_deploy_template_test.go +++ b/builder/azure/arm/step_deploy_template_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,15 +7,18 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepDeployTemplateShouldFailIfDeployFails(t *testing.T) { var testSubject = &StepDeployTemplate{ - deploy: func(string, string, *TemplateParameters) error { return fmt.Errorf("!! Unit Test FAIL !!") }, - say: func(message string) {}, - error: func(e error) {}, + template: Linux, + deploy: func(string, string, *TemplateParameters, <-chan struct{}) error { + return fmt.Errorf("!! Unit Test FAIL !!") + }, + say: func(message string) {}, + error: func(e error) {}, } stateBag := createTestStateBagStepDeployTemplate() @@ -32,9 +35,10 @@ func TestStepDeployTemplateShouldFailIfDeployFails(t *testing.T) { func TestStepDeployTemplateShouldPassIfDeployPasses(t *testing.T) { var testSubject = &StepDeployTemplate{ - deploy: func(string, string, *TemplateParameters) error { return nil }, - say: func(message string) {}, - error: func(e error) {}, + template: Linux, + deploy: func(string, string, *TemplateParameters, <-chan struct{}) error { return nil }, + say: func(message string) {}, + error: func(e error) {}, } stateBag := createTestStateBagStepDeployTemplate() @@ -55,7 +59,8 @@ func TestStepDeployTemplateShouldTakeStepArgumentsFromStateBag(t *testing.T) { var actualTemplateParameters *TemplateParameters var testSubject = &StepDeployTemplate{ - deploy: func(resourceGroupName string, deploymentName string, templateParameter *TemplateParameters) error { + template: Linux, + deploy: func(resourceGroupName string, deploymentName string, templateParameter *TemplateParameters, cancelCh <-chan struct{}) error { actualResourceGroupName = resourceGroupName actualDeploymentName = deploymentName actualTemplateParameters = templateParameter diff --git a/builder/azure/arm/step_get_certificate.go b/builder/azure/arm/step_get_certificate.go new file mode 100644 index 000000000..f8fc89940 --- /dev/null +++ b/builder/azure/arm/step_get_certificate.go @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "fmt" + "time" + + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/packer" +) + +type StepGetCertificate struct { + client *AzureClient + template string + get func(keyVaultName string, secretName string) (string, error) + say func(message string) + error func(e error) + pause func() +} + +func NewStepGetCertificate(client *AzureClient, ui packer.Ui) *StepGetCertificate { + var step = &StepGetCertificate{ + client: client, + say: func(message string) { ui.Say(message) }, + error: func(e error) { ui.Error(e.Error()) }, + pause: func() { time.Sleep(30 * time.Second) }, + } + + step.get = step.getCertificateUrl + return step +} + +func (s *StepGetCertificate) getCertificateUrl(keyVaultName string, secretName string) (string, error) { + secret, err := s.client.GetSecret(keyVaultName, secretName) + if err != nil { + return "", err + } + + return *secret.ID, err +} + +func (s *StepGetCertificate) Run(state multistep.StateBag) multistep.StepAction { + s.say("Getting the certificate's URL ...") + + var keyVaultName = state.Get(constants.ArmKeyVaultName).(string) + + s.say(fmt.Sprintf(" -> Key Vault Name : '%s'", keyVaultName)) + s.say(fmt.Sprintf(" -> Key Vault Secret Name : '%s'", DefaultSecretName)) + + var err error + var url string + for i := 0; i < 5; i++ { + url, err = s.get(keyVaultName, DefaultSecretName) + if err == nil { + break + } + + s.say(fmt.Sprintf(" ...failed to get certificate URL, retry(%d)", i)) + s.pause() + } + + if err != nil { + state.Put(constants.Error, err) + s.error(err) + + return multistep.ActionHalt + } + + s.say(fmt.Sprintf(" -> Certificate URL : '%s'", url)) + state.Put(constants.ArmCertificateUrl, url) + + return multistep.ActionContinue +} + +func (*StepGetCertificate) Cleanup(multistep.StateBag) { +} diff --git a/builder/azure/arm/step_get_certificate_test.go b/builder/azure/arm/step_get_certificate_test.go new file mode 100644 index 000000000..7db490ddc --- /dev/null +++ b/builder/azure/arm/step_get_certificate_test.go @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "fmt" + "testing" + + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" +) + +func TestStepGetCertificateShouldFailIfGetFails(t *testing.T) { + var testSubject = &StepGetCertificate{ + get: func(string, string) (string, error) { return "", fmt.Errorf("!! Unit Test FAIL !!") }, + say: func(message string) {}, + error: func(e error) {}, + pause: func() {}, + } + + stateBag := createTestStateBagStepGetCertificate() + + var result = testSubject.Run(stateBag) + if result != multistep.ActionHalt { + t.Fatalf("Expected the step to return 'ActionHalt', but got '%d'.", result) + } + + if _, ok := stateBag.GetOk(constants.Error); ok == false { + t.Fatalf("Expected the step to set stateBag['%s'], but it was not.", constants.Error) + } +} + +func TestStepGetCertificateShouldPassIfGetPasses(t *testing.T) { + var testSubject = &StepGetCertificate{ + get: func(string, string) (string, error) { return "", nil }, + say: func(message string) {}, + error: func(e error) {}, + pause: func() {}, + } + + stateBag := createTestStateBagStepGetCertificate() + + var result = testSubject.Run(stateBag) + if result != multistep.ActionContinue { + t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result) + } + + if _, ok := stateBag.GetOk(constants.Error); ok == true { + t.Fatalf("Expected the step to not set stateBag['%s'], but it was.", constants.Error) + } +} + +func TestStepGetCertificateShouldTakeStepArgumentsFromStateBag(t *testing.T) { + var actualKeyVaultName string + var actualSecretName string + + var testSubject = &StepGetCertificate{ + get: func(keyVaultName string, secretName string) (string, error) { + actualKeyVaultName = keyVaultName + actualSecretName = secretName + + return "http://key.vault/1", nil + }, + say: func(message string) {}, + error: func(e error) {}, + pause: func() {}, + } + + stateBag := createTestStateBagStepGetCertificate() + var result = testSubject.Run(stateBag) + + if result != multistep.ActionContinue { + t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result) + } + + var expectedKeyVaultName = stateBag.Get(constants.ArmKeyVaultName).(string) + + if actualKeyVaultName != expectedKeyVaultName { + t.Fatalf("Expected StepGetCertificate to source 'constants.ArmKeyVaultName' from the state bag, but it did not.") + } + if actualSecretName != DefaultSecretName { + t.Fatalf("Expected StepGetCertificate to use default value for secret, but it did not.") + } + + expectedCertificateUrl, ok := stateBag.GetOk(constants.ArmCertificateUrl) + if !ok { + t.Fatalf("Expected the state bag to have a value for '%s', but it did not.", constants.ArmCertificateUrl) + } + + if expectedCertificateUrl != "http://key.vault/1" { + t.Fatalf("Expected the value of stateBag[%s] to be 'http://key.vault/1', but got '%s'.", constants.ArmCertificateUrl, expectedCertificateUrl) + } +} + +func createTestStateBagStepGetCertificate() multistep.StateBag { + stateBag := new(multistep.BasicStateBag) + stateBag.Put(constants.ArmKeyVaultName, "Unit Test: KeyVaultName") + + return stateBag +} diff --git a/builder/azure/arm/step_get_ip_address.go b/builder/azure/arm/step_get_ip_address.go index 43ebf4778..5f8be6bd8 100644 --- a/builder/azure/arm/step_get_ip_address.go +++ b/builder/azure/arm/step_get_ip_address.go @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( "fmt" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) @@ -55,7 +55,7 @@ func (s *StepGetIPAddress) Run(state multistep.StateBag) multistep.StepAction { return multistep.ActionHalt } - s.say(fmt.Sprintf(" -> SSHHost : '%s'", address)) + s.say(fmt.Sprintf(" -> Public IP : '%s'", address)) state.Put(constants.SSHHost, address) return multistep.ActionContinue diff --git a/builder/azure/arm/step_get_ip_address_test.go b/builder/azure/arm/step_get_ip_address_test.go index 2a3749033..14d4d73fe 100644 --- a/builder/azure/arm/step_get_ip_address_test.go +++ b/builder/azure/arm/step_get_ip_address_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,8 +7,8 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepGetIPAddressShouldFailIfGetFails(t *testing.T) { @@ -75,11 +75,11 @@ func TestStepGetIPAddressShouldTakeStepArgumentsFromStateBag(t *testing.T) { var expectedIPAddressName = stateBag.Get(constants.ArmPublicIPAddressName).(string) if actualIPAddressName != expectedIPAddressName { - t.Fatalf("Expected StepValidateTemplate to source 'constants.ArmIPAddressName' from the state bag, but it did not.") + t.Fatalf("Expected StepGetIPAddress to source 'constants.ArmIPAddressName' from the state bag, but it did not.") } if actualResourceGroupName != expectedResourceGroupName { - t.Fatalf("Expected StepValidateTemplate to source 'constants.ArmResourceGroupName' from the state bag, but it did not.") + t.Fatalf("Expected StepGetIPAddress to source 'constants.ArmResourceGroupName' from the state bag, but it did not.") } expectedIPAddress, ok := stateBag.GetOk(constants.SSHHost) diff --git a/builder/azure/arm/step_get_os_disk.go b/builder/azure/arm/step_get_os_disk.go index 810159ea2..4d67057db 100644 --- a/builder/azure/arm/step_get_os_disk.go +++ b/builder/azure/arm/step_get_os_disk.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm diff --git a/builder/azure/arm/step_get_os_disk_test.go b/builder/azure/arm/step_get_os_disk_test.go index d0cc45f07..a0219762c 100644 --- a/builder/azure/arm/step_get_os_disk_test.go +++ b/builder/azure/arm/step_get_os_disk_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -95,7 +95,7 @@ func TestStepGetOSDiskShouldTakeValidateArgumentsFromStateBag(t *testing.T) { } if expectedOSDiskVhd != "test.vhd" { - t.Fatalf("Expected the value of stateBag[%s] to be '127.0.0.1', but got '%s'.", constants.ArmOSDiskVhd, expectedOSDiskVhd) + t.Fatalf("Expected the value of stateBag[%s] to be 'test.vhd', but got '%s'.", constants.ArmOSDiskVhd, expectedOSDiskVhd) } } diff --git a/builder/azure/arm/step_power_off_compute.go b/builder/azure/arm/step_power_off_compute.go index 9021d2e21..bd6044745 100644 --- a/builder/azure/arm/step_power_off_compute.go +++ b/builder/azure/arm/step_power_off_compute.go @@ -1,19 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( "fmt" - "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/builder/azure/common" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) type StepPowerOffCompute struct { client *AzureClient - powerOff func(resourceGroupName string, computeName string) error + powerOff func(resourceGroupName string, computeName string, cancelCh <-chan struct{}) error say func(message string) error func(e error) } @@ -29,13 +30,12 @@ func NewStepPowerOffCompute(client *AzureClient, ui packer.Ui) *StepPowerOffComp return step } -func (s *StepPowerOffCompute) powerOffCompute(resourceGroupName string, computeName string) error { - res, err := s.client.PowerOff(resourceGroupName, computeName) +func (s *StepPowerOffCompute) powerOffCompute(resourceGroupName string, computeName string, cancelCh <-chan struct{}) error { + _, err := s.client.PowerOff(resourceGroupName, computeName, cancelCh) if err != nil { return err } - s.client.VirtualMachinesClient.PollAsNeeded(res.Response) return nil } @@ -48,15 +48,11 @@ func (s *StepPowerOffCompute) Run(state multistep.StateBag) multistep.StepAction s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName)) s.say(fmt.Sprintf(" -> ComputeName : '%s'", computeName)) - err := s.powerOff(resourceGroupName, computeName) - if err != nil { - state.Put(constants.Error, err) - s.error(err) + result := common.StartInterruptibleTask( + func() bool { return common.IsStateCancelled(state) }, + func(cancelCh <-chan struct{}) error { return s.powerOff(resourceGroupName, computeName, cancelCh) }) - return multistep.ActionHalt - } - - return multistep.ActionContinue + return processInterruptibleResult(result, s.error, state) } func (*StepPowerOffCompute) Cleanup(multistep.StateBag) { diff --git a/builder/azure/arm/step_power_off_compute_test.go b/builder/azure/arm/step_power_off_compute_test.go index 64553e3a0..244265d46 100644 --- a/builder/azure/arm/step_power_off_compute_test.go +++ b/builder/azure/arm/step_power_off_compute_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,13 +7,13 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepPowerOffComputeShouldFailIfPowerOffFails(t *testing.T) { var testSubject = &StepPowerOffCompute{ - powerOff: func(string, string) error { return fmt.Errorf("!! Unit Test FAIL !!") }, + powerOff: func(string, string, <-chan struct{}) error { return fmt.Errorf("!! Unit Test FAIL !!") }, say: func(message string) {}, error: func(e error) {}, } @@ -32,7 +32,7 @@ func TestStepPowerOffComputeShouldFailIfPowerOffFails(t *testing.T) { func TestStepPowerOffComputeShouldPassIfPowerOffPasses(t *testing.T) { var testSubject = &StepPowerOffCompute{ - powerOff: func(string, string) error { return nil }, + powerOff: func(string, string, <-chan struct{}) error { return nil }, say: func(message string) {}, error: func(e error) {}, } @@ -54,7 +54,7 @@ func TestStepPowerOffComputeShouldTakeStepArgumentsFromStateBag(t *testing.T) { var actualComputeName string var testSubject = &StepPowerOffCompute{ - powerOff: func(resourceGroupName string, computeName string) error { + powerOff: func(resourceGroupName string, computeName string, cancelCh <-chan struct{}) error { actualResourceGroupName = resourceGroupName actualComputeName = computeName diff --git a/builder/azure/arm/step_set_certificate.go b/builder/azure/arm/step_set_certificate.go new file mode 100644 index 000000000..4bf6cc516 --- /dev/null +++ b/builder/azure/arm/step_set_certificate.go @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/packer" +) + +type StepSetCertificate struct { + config *Config + say func(message string) + error func(e error) +} + +func NewStepSetCertificate(config *Config, ui packer.Ui) *StepSetCertificate { + var step = &StepSetCertificate{ + config: config, + say: func(message string) { ui.Say(message) }, + error: func(e error) { ui.Error(e.Error()) }, + } + + return step +} + +func (s *StepSetCertificate) Run(state multistep.StateBag) multistep.StepAction { + s.say("Setting the certificate's URL ...") + + var winRMCertificateUrl = state.Get(constants.ArmCertificateUrl).(string) + s.config.tmpWinRMCertificateUrl = winRMCertificateUrl + + state.Put(constants.ArmTemplateParameters, s.config.toTemplateParameters()) + return multistep.ActionContinue +} + +func (*StepSetCertificate) Cleanup(multistep.StateBag) { +} diff --git a/builder/azure/arm/step_set_certificate_test.go b/builder/azure/arm/step_set_certificate_test.go new file mode 100644 index 000000000..4ee8c0cb4 --- /dev/null +++ b/builder/azure/arm/step_set_certificate_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "testing" + + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" +) + +func TestStepSetCertificateShouldPassIfGetPasses(t *testing.T) { + var testSubject = &StepSetCertificate{ + config: new(Config), + say: func(message string) {}, + error: func(e error) {}, + } + + stateBag := createTestStateBagStepSetCertificate() + + var result = testSubject.Run(stateBag) + if result != multistep.ActionContinue { + t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result) + } + + if _, ok := stateBag.GetOk(constants.Error); ok == true { + t.Fatalf("Expected the step to not set stateBag['%s'], but it was.", constants.Error) + } +} + +func TestStepSetCertificateShouldTakeStepArgumentsFromStateBag(t *testing.T) { + var testSubject = &StepSetCertificate{ + config: new(Config), + say: func(message string) {}, + error: func(e error) {}, + } + + stateBag := createTestStateBagStepSetCertificate() + var result = testSubject.Run(stateBag) + + if result != multistep.ActionContinue { + t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result) + } + + _, ok := stateBag.GetOk(constants.ArmTemplateParameters) + if !ok { + t.Fatalf("Expected the state bag to have a value for '%s', but it did not.", constants.ArmTemplateParameters) + } +} + +func createTestStateBagStepSetCertificate() multistep.StateBag { + stateBag := new(multistep.BasicStateBag) + stateBag.Put(constants.ArmCertificateUrl, "Unit Test: Certificate URL") + return stateBag +} diff --git a/builder/azure/arm/step_test.go b/builder/azure/arm/step_test.go new file mode 100644 index 000000000..c3c1e9438 --- /dev/null +++ b/builder/azure/arm/step_test.go @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package arm + +import ( + "fmt" + "github.com/mitchellh/packer/builder/azure/common" + "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" + "testing" +) + +func TestProcessStepResultShouldContinueForNonErrors(t *testing.T) { + stateBag := new(multistep.BasicStateBag) + + code := processStepResult(nil, func(error) { t.Fatalf("Should not be called!") }, stateBag) + if _, ok := stateBag.GetOk(constants.Error); ok { + t.Errorf("Error was nil, but was still in the state bag.") + } + + if code != multistep.ActionContinue { + t.Errorf("Expected ActionContinue(%d), but got=%d", multistep.ActionContinue, code) + } +} + +func TestProcessStepResultShouldHaltOnError(t *testing.T) { + stateBag := new(multistep.BasicStateBag) + isSaidError := false + + code := processStepResult(fmt.Errorf("boom"), func(error) { isSaidError = true }, stateBag) + if _, ok := stateBag.GetOk(constants.Error); !ok { + t.Errorf("Error was non nil, but was not in the state bag.") + } + + if !isSaidError { + t.Errorf("Expected error to be said, but it was not.") + } + + if code != multistep.ActionHalt { + t.Errorf("Expected ActionHalt(%d), but got=%d", multistep.ActionHalt, code) + } +} + +func TestProcessStepResultShouldContinueOnSuccessfulTask(t *testing.T) { + stateBag := new(multistep.BasicStateBag) + result := common.InterruptibleTaskResult{ + IsCancelled: false, + Err: nil, + } + + code := processInterruptibleResult(result, func(error) { t.Fatalf("Should not be called!") }, stateBag) + if _, ok := stateBag.GetOk(constants.Error); ok { + t.Errorf("Error was nil, but was still in the state bag.") + } + + if code != multistep.ActionContinue { + t.Errorf("Expected ActionContinue(%d), but got=%d", multistep.ActionContinue, code) + } +} + +func TestProcessStepResultShouldHaltWhenTaskIsCancelled(t *testing.T) { + stateBag := new(multistep.BasicStateBag) + result := common.InterruptibleTaskResult{ + IsCancelled: true, + Err: nil, + } + + code := processInterruptibleResult(result, func(error) { t.Fatalf("Should not be called!") }, stateBag) + if _, ok := stateBag.GetOk(constants.Error); ok { + t.Errorf("Error was nil, but was still in the state bag.") + } + + if code != multistep.ActionHalt { + t.Errorf("Expected ActionHalt(%d), but got=%d", multistep.ActionHalt, code) + } +} + +func TestProcessStepResultShouldHaltOnTaskError(t *testing.T) { + stateBag := new(multistep.BasicStateBag) + isSaidError := false + result := common.InterruptibleTaskResult{ + IsCancelled: false, + Err: fmt.Errorf("boom"), + } + + code := processInterruptibleResult(result, func(error) { isSaidError = true }, stateBag) + if _, ok := stateBag.GetOk(constants.Error); !ok { + t.Errorf("Error was *not* nil, but was not in the state bag.") + } + + if !isSaidError { + t.Errorf("Expected error to be said, but it was not.") + } + + if code != multistep.ActionHalt { + t.Errorf("Expected ActionHalt(%d), but got=%d", multistep.ActionHalt, code) + } +} diff --git a/builder/azure/arm/step_validate_template.go b/builder/azure/arm/step_validate_template.go index 1a083b9a0..52f65e8ac 100644 --- a/builder/azure/arm/step_validate_template.go +++ b/builder/azure/arm/step_validate_template.go @@ -1,28 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm import ( "fmt" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" ) type StepValidateTemplate struct { client *AzureClient + template string validate func(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error say func(message string) error func(e error) } -func NewStepValidateTemplate(client *AzureClient, ui packer.Ui) *StepValidateTemplate { +func NewStepValidateTemplate(client *AzureClient, ui packer.Ui, template string) *StepValidateTemplate { var step = &StepValidateTemplate{ - client: client, - say: func(message string) { ui.Say(message) }, - error: func(e error) { ui.Error(e.Error()) }, + client: client, + template: template, + say: func(message string) { ui.Say(message) }, + error: func(e error) { ui.Error(e.Error()) }, } step.validate = step.validateTemplate @@ -30,7 +32,7 @@ func NewStepValidateTemplate(client *AzureClient, ui packer.Ui) *StepValidateTem } func (s *StepValidateTemplate) validateTemplate(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error { - factory := newDeploymentFactory(Linux) + factory := newDeploymentFactory(s.template) deployment, err := factory.create(*templateParameters) if err != nil { @@ -52,14 +54,7 @@ func (s *StepValidateTemplate) Run(state multistep.StateBag) multistep.StepActio s.say(fmt.Sprintf(" -> DeploymentName : '%s'", deploymentName)) err := s.validate(resourceGroupName, deploymentName, templateParameters) - if err != nil { - state.Put(constants.Error, err) - s.error(err) - - return multistep.ActionHalt - } - - return multistep.ActionContinue + return processStepResult(err, s.error, state) } func (*StepValidateTemplate) Cleanup(multistep.StateBag) { diff --git a/builder/azure/arm/step_validate_template_test.go b/builder/azure/arm/step_validate_template_test.go index 1242666fa..d2c76e0de 100644 --- a/builder/azure/arm/step_validate_template_test.go +++ b/builder/azure/arm/step_validate_template_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -7,13 +7,14 @@ import ( "fmt" "testing" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" ) func TestStepValidateTemplateShouldFailIfValidateFails(t *testing.T) { var testSubject = &StepValidateTemplate{ + template: Linux, validate: func(string, string, *TemplateParameters) error { return fmt.Errorf("!! Unit Test FAIL !!") }, say: func(message string) {}, error: func(e error) {}, @@ -33,6 +34,7 @@ func TestStepValidateTemplateShouldFailIfValidateFails(t *testing.T) { func TestStepValidateTemplateShouldPassIfValidatePasses(t *testing.T) { var testSubject = &StepValidateTemplate{ + template: Linux, validate: func(string, string, *TemplateParameters) error { return nil }, say: func(message string) {}, error: func(e error) {}, @@ -56,6 +58,7 @@ func TestStepValidateTemplateShouldTakeStepArgumentsFromStateBag(t *testing.T) { var actualTemplateParameters *TemplateParameters var testSubject = &StepValidateTemplate{ + template: Linux, validate: func(resourceGroupName string, deploymentName string, templateParameter *TemplateParameters) error { actualResourceGroupName = resourceGroupName actualDeploymentName = deploymentName diff --git a/builder/azure/arm/template.go b/builder/azure/arm/template.go index af157c92b..3bab5f5ea 100644 --- a/builder/azure/arm/template.go +++ b/builder/azure/arm/template.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -27,13 +27,16 @@ const Linux = `{ "imageSku": { "type": "string" }, + "imageVersion": { + "type": "string" + }, "osDiskName": { "type": "string" }, "sshAuthorizedKey": { "type": "string" }, - "storageAccountName": { + "storageAccountBlobEndpoint": { "type": "string" }, "vmSize": { @@ -151,12 +154,12 @@ const Linux = `{ "publisher": "[parameters('imagePublisher')]", "offer": "[parameters('imageOffer')]", "sku": "[parameters('imageSku')]", - "version": "latest" + "version": "[parameters('imageVersion')]" }, "osDisk": { "name": "osdisk", "vhd": { - "uri": "[concat('http://',parameters('storageAccountName'),'.blob.core.windows.net/',variables('vmStorageAccountContainerName'),'/', parameters('osDiskName'),'.vhd')]" + "uri": "[concat(parameters('storageAccountBlobEndpoint'),variables('vmStorageAccountContainerName'),'/', parameters('osDiskName'),'.vhd')]" }, "caching": "ReadWrite", "createOption": "FromImage" @@ -178,3 +181,309 @@ const Linux = `{ } ] }` + +// Template to deploy a KeyVault. +// +// NOTE: the parameters for the KeyVault template are identical to Windows +// template. Keeping these values in sync simplifies the code at the expense +// of template bloat. This bloat may be addressed in the future. +const KeyVault = `{ + "$schema": "http://schema.management.azure.com/schemas/2014-04-01-preview/deploymentTemplate.json", + "contentVersion": "1.0.0.0", + "parameters": { + "adminUserName": { + "type": "string" + }, + "adminPassword": { + "type": "securestring" + }, + "dnsNameForPublicIP": { + "type": "string" + }, + "imageOffer": { + "type": "string" + }, + "imagePublisher": { + "type": "string" + }, + "imageSku": { + "type": "string" + }, + "imageVersion": { + "type": "string" + }, + "keyVaultName": { + "type": "string" + }, + "keyVaultSecretValue": { + "type": "securestring" + }, + "objectId": { + "type": "string" + }, + "osDiskName": { + "type": "string" + }, + "storageAccountBlobEndpoint": { + "type": "string" + }, + "tenantId": { + "type": "string" + }, + "vmName": { + "type": "string" + }, + "vmSize": { + "type": "string" + }, + "winRMCertificateUrl": { + "type": "string" + } + }, + "variables": { + "apiVersion": "2015-06-01", + "location": "[resourceGroup().location]", + "keyVaultSecretName": "packerKeyVaultSecret" + }, + "resources": [ + { + "apiVersion": "[variables('apiVersion')]", + "type": "Microsoft.KeyVault/vaults", + "name": "[parameters('keyVaultName')]", + "location": "[variables('location')]", + "properties": { + "enabledForDeployment": "true", + "enabledForTemplateDeployment": "true", + "tenantId": "[parameters('tenantId')]", + "accessPolicies": [ + { + "tenantId": "[parameters('tenantId')]", + "objectId": "[parameters('objectId')]", + "permissions": { + "keys": [ "all" ], + "secrets": [ "all" ] + } + } + ], + "sku": { + "name": "standard", + "family": "A" + } + }, + "resources": [ + { + "apiVersion": "[variables('apiVersion')]", + "type": "secrets", + "name": "[variables('keyVaultSecretName')]", + "dependsOn": [ + "[concat('Microsoft.KeyVault/vaults/', parameters('keyVaultName'))]" + ], + "properties": { + "value": "[parameters('keyVaultSecretValue')]" + } + } + ] + } + ] +}` + +const Windows = `{ + "$schema": "http://schema.management.azure.com/schemas/2014-04-01-preview/deploymentTemplate.json", + "contentVersion": "1.0.0.0", + "parameters": { + "adminUserName": { + "type": "string" + }, + "adminPassword": { + "type": "securestring" + }, + "dnsNameForPublicIP": { + "type": "string" + }, + "imageOffer": { + "type": "string" + }, + "imagePublisher": { + "type": "string" + }, + "imageSku": { + "type": "string" + }, + "imageVersion": { + "type": "string" + }, + "keyVaultName": { + "type": "string" + }, + "keyVaultSecretValue": { + "type": "securestring" + }, + "objectId": { + "type": "string" + }, + "osDiskName": { + "type": "string" + }, + "storageAccountBlobEndpoint": { + "type": "string" + }, + "tenantId": { + "type": "string" + }, + "vmName": { + "type": "string" + }, + "vmSize": { + "type": "string" + }, + "winRMCertificateUrl": { + "type": "string" + } + }, + "variables": { + "addressPrefix": "10.0.0.0/16", + "apiVersion": "2015-06-15", + "location": "[resourceGroup().location]", + "nicName": "packerNic", + "publicIPAddressName": "packerPublicIP", + "publicIPAddressType": "Dynamic", + "subnetName": "packerSubnet", + "subnetAddressPrefix": "10.0.0.0/24", + "subnetRef": "[concat(variables('vnetID'),'/subnets/',variables('subnetName'))]", + "virtualNetworkName": "packerNetwork", + "vmStorageAccountContainerName": "images", + "vnetID": "[resourceId('Microsoft.Network/virtualNetworks', variables('virtualNetworkName'))]" + }, + "resources": [ + { + "apiVersion": "[variables('apiVersion')]", + "type": "Microsoft.Network/publicIPAddresses", + "name": "[variables('publicIPAddressName')]", + "location": "[variables('location')]", + "properties": { + "publicIPAllocationMethod": "[variables('publicIPAddressType')]", + "dnsSettings": { + "domainNameLabel": "[parameters('dnsNameForPublicIP')]" + } + } + }, + { + "apiVersion": "[variables('apiVersion')]", + "type": "Microsoft.Network/virtualNetworks", + "name": "[variables('virtualNetworkName')]", + "location": "[variables('location')]", + "properties": { + "addressSpace": { + "addressPrefixes": [ + "[variables('addressPrefix')]" + ] + }, + "subnets": [ + { + "name": "[variables('subnetName')]", + "properties": { + "addressPrefix": "[variables('subnetAddressPrefix')]" + } + } + ] + } + }, + { + "apiVersion": "[variables('apiVersion')]", + "type": "Microsoft.Network/networkInterfaces", + "name": "[variables('nicName')]", + "location": "[variables('location')]", + "dependsOn": [ + "[concat('Microsoft.Network/publicIPAddresses/', variables('publicIPAddressName'))]", + "[concat('Microsoft.Network/virtualNetworks/', variables('virtualNetworkName'))]" + ], + "properties": { + "ipConfigurations": [ + { + "name": "ipconfig", + "properties": { + "privateIPAllocationMethod": "Dynamic", + "publicIPAddress": { + "id": "[resourceId('Microsoft.Network/publicIPAddresses', variables('publicIPAddressName'))]" + }, + "subnet": { + "id": "[variables('subnetRef')]" + } + } + } + ] + } + }, + { + "apiVersion": "[variables('apiVersion')]", + "type": "Microsoft.Compute/virtualMachines", + "name": "[parameters('vmName')]", + "location": "[variables('location')]", + "dependsOn": [ + "[concat('Microsoft.Network/networkInterfaces/', variables('nicName'))]" + ], + "properties": { + "hardwareProfile": { + "vmSize": "[parameters('vmSize')]" + }, + "osProfile": { + "computerName": "[parameters('vmName')]", + "adminUsername": "[parameters('adminUsername')]", + "adminPassword": "[parameters('adminPassword')]", + "secrets": [ + { + "sourceVault": { + "id": "[resourceId(resourceGroup().name, 'Microsoft.KeyVault/vaults', parameters('keyVaultName'))]" + }, + "vaultCertificates": [ + { + "certificateUrl": "[parameters('winRMCertificateUrl')]", + "certificateStore": "My" + } + ] + } + ], + "windowsConfiguration": { + "provisionVMAgent": "true", + "winRM": { + "listeners": [{ + "protocol": "https", + "certificateUrl": "[parameters('winRMCertificateUrl')]" + } + ] + }, + "enableAutomaticUpdates": "true" + } + }, + "storageProfile": { + "imageReference": { + "publisher": "[parameters('imagePublisher')]", + "offer": "[parameters('imageOffer')]", + "sku": "[parameters('imageSku')]", + "version": "[parameters('imageVersion')]" + }, + "osDisk": { + "name": "osdisk", + "vhd": { + "uri": "[concat(parameters('storageAccountBlobEndpoint'),variables('vmStorageAccountContainerName'),'/', parameters('osDiskName'),'.vhd')]" + }, + "caching": "ReadWrite", + "createOption": "FromImage" + } + }, + "networkProfile": { + "networkInterfaces": [ + { + "id": "[resourceId('Microsoft.Network/networkInterfaces', variables('nicName'))]" + } + ] + }, + "diagnosticsProfile": { + "bootDiagnostics": { + "enabled": "false" + } + } + } + } + ] +}` diff --git a/builder/azure/arm/template_parameters.go b/builder/azure/arm/template_parameters.go index 195f51a0c..489cb0a76 100644 --- a/builder/azure/arm/template_parameters.go +++ b/builder/azure/arm/template_parameters.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -21,15 +21,21 @@ type TemplateParameter struct { } type TemplateParameters struct { - AdminUsername *TemplateParameter `json:"adminUsername,omitempty"` - AdminPassword *TemplateParameter `json:"adminPassword,omitempty"` - DnsNameForPublicIP *TemplateParameter `json:"dnsNameForPublicIP,omitempty"` - ImageOffer *TemplateParameter `json:"imageOffer,omitempty"` - ImagePublisher *TemplateParameter `json:"imagePublisher,omitempty"` - ImageSku *TemplateParameter `json:"imageSku,omitempty"` - OSDiskName *TemplateParameter `json:"osDiskName,omitempty"` - SshAuthorizedKey *TemplateParameter `json:"sshAuthorizedKey,omitempty"` - StorageAccountName *TemplateParameter `json:"storageAccountName,omitempty"` - VMSize *TemplateParameter `json:"vmSize,omitempty"` - VMName *TemplateParameter `json:"vmName,omitempty"` + AdminUsername *TemplateParameter `json:"adminUsername,omitempty"` + AdminPassword *TemplateParameter `json:"adminPassword,omitempty"` + DnsNameForPublicIP *TemplateParameter `json:"dnsNameForPublicIP,omitempty"` + ImageOffer *TemplateParameter `json:"imageOffer,omitempty"` + ImagePublisher *TemplateParameter `json:"imagePublisher,omitempty"` + ImageSku *TemplateParameter `json:"imageSku,omitempty"` + ImageVersion *TemplateParameter `json:"imageVersion,omitempty"` + KeyVaultName *TemplateParameter `json:"keyVaultName,omitempty"` + KeyVaultSecretValue *TemplateParameter `json:"keyVaultSecretValue,omitempty"` + ObjectId *TemplateParameter `json:"objectId,omitempty"` + OSDiskName *TemplateParameter `json:"osDiskName,omitempty"` + SshAuthorizedKey *TemplateParameter `json:"sshAuthorizedKey,omitempty"` + StorageAccountBlobEndpoint *TemplateParameter `json:"storageAccountBlobEndpoint,omitempty"` + TenantId *TemplateParameter `json:"tenantId,omitempty"` + VMSize *TemplateParameter `json:"vmSize,omitempty"` + VMName *TemplateParameter `json:"vmName,omitempty"` + WinRMCertificateUrl *TemplateParameter `json:"winRMCertificateUrl,omitempty"` } diff --git a/builder/azure/arm/template_parameters_test.go b/builder/azure/arm/template_parameters_test.go index 910216496..ce5491350 100644 --- a/builder/azure/arm/template_parameters_test.go +++ b/builder/azure/arm/template_parameters_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -12,17 +12,17 @@ import ( func TestTemplateParametersShouldHaveExpectedKeys(t *testing.T) { params := TemplateParameters{ - AdminUsername: &TemplateParameter{"sentinel"}, - AdminPassword: &TemplateParameter{"sentinel"}, - DnsNameForPublicIP: &TemplateParameter{"sentinel"}, - ImageOffer: &TemplateParameter{"sentinel"}, - ImagePublisher: &TemplateParameter{"sentinel"}, - ImageSku: &TemplateParameter{"sentinel"}, - OSDiskName: &TemplateParameter{"sentinel"}, - SshAuthorizedKey: &TemplateParameter{"sentinel"}, - StorageAccountName: &TemplateParameter{"sentinel"}, - VMName: &TemplateParameter{"sentinel"}, - VMSize: &TemplateParameter{"sentinel"}, + AdminUsername: &TemplateParameter{"sentinel"}, + AdminPassword: &TemplateParameter{"sentinel"}, + DnsNameForPublicIP: &TemplateParameter{"sentinel"}, + ImageOffer: &TemplateParameter{"sentinel"}, + ImagePublisher: &TemplateParameter{"sentinel"}, + ImageSku: &TemplateParameter{"sentinel"}, + OSDiskName: &TemplateParameter{"sentinel"}, + SshAuthorizedKey: &TemplateParameter{"sentinel"}, + StorageAccountBlobEndpoint: &TemplateParameter{"sentinel"}, + VMName: &TemplateParameter{"sentinel"}, + VMSize: &TemplateParameter{"sentinel"}, } bs, err := json.Marshal(params) @@ -46,7 +46,7 @@ func TestTemplateParametersShouldHaveExpectedKeys(t *testing.T) { "imageSku", "osDiskName", "sshAuthorizedKey", - "storageAccountName", + "storageAccountBlobEndpoint", "vmSize", "vmName", } @@ -61,17 +61,17 @@ func TestTemplateParametersShouldHaveExpectedKeys(t *testing.T) { func TestParameterValuesShouldBeSet(t *testing.T) { params := TemplateParameters{ - AdminUsername: &TemplateParameter{"adminusername00"}, - AdminPassword: &TemplateParameter{"adminpassword00"}, - DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"}, - ImageOffer: &TemplateParameter{"imageoffer00"}, - ImagePublisher: &TemplateParameter{"imagepublisher00"}, - ImageSku: &TemplateParameter{"imagesku00"}, - OSDiskName: &TemplateParameter{"osdiskname00"}, - SshAuthorizedKey: &TemplateParameter{"sshauthorizedkey00"}, - StorageAccountName: &TemplateParameter{"storageaccountname00"}, - VMName: &TemplateParameter{"vmname00"}, - VMSize: &TemplateParameter{"vmsize00"}, + AdminUsername: &TemplateParameter{"adminusername00"}, + AdminPassword: &TemplateParameter{"adminpassword00"}, + DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"}, + ImageOffer: &TemplateParameter{"imageoffer00"}, + ImagePublisher: &TemplateParameter{"imagepublisher00"}, + ImageSku: &TemplateParameter{"imagesku00"}, + OSDiskName: &TemplateParameter{"osdiskname00"}, + SshAuthorizedKey: &TemplateParameter{"sshauthorizedkey00"}, + StorageAccountBlobEndpoint: &TemplateParameter{"storageaccountblobendpoint00"}, + VMName: &TemplateParameter{"vmname00"}, + VMSize: &TemplateParameter{"vmsize00"}, } bs, err := json.Marshal(params) diff --git a/builder/azure/arm/tempname.go b/builder/azure/arm/tempname.go index 959d4d9a5..f5f577962 100644 --- a/builder/azure/arm/tempname.go +++ b/builder/azure/arm/tempname.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm @@ -15,11 +15,13 @@ const ( ) type TempName struct { - AdminPassword string - ComputeName string - DeploymentName string - ResourceGroupName string - OSDiskName string + AdminPassword string + CertificatePassword string + ComputeName string + DeploymentName string + KeyVaultName string + ResourceGroupName string + OSDiskName string } func NewTempName() *TempName { @@ -28,10 +30,12 @@ func NewTempName() *TempName { suffix := common.RandomString(TempNameAlphabet, 10) tempName.ComputeName = fmt.Sprintf("pkrvm%s", suffix) tempName.DeploymentName = fmt.Sprintf("pkrdp%s", suffix) + tempName.KeyVaultName = fmt.Sprintf("pkrkv%s", suffix) tempName.OSDiskName = fmt.Sprintf("pkros%s", suffix) tempName.ResourceGroupName = fmt.Sprintf("packer-Resource-Group-%s", suffix) tempName.AdminPassword = common.RandomString(TempPasswordAlphabet, 32) + tempName.CertificatePassword = common.RandomString(TempPasswordAlphabet, 32) return tempName } diff --git a/builder/azure/arm/tempname_test.go b/builder/azure/arm/tempname_test.go index c812e6c7d..114ecedf3 100644 --- a/builder/azure/arm/tempname_test.go +++ b/builder/azure/arm/tempname_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package arm diff --git a/builder/azure/common/constants/goos.go b/builder/azure/common/constants/goos.go index f3d957fa7..b904bf9d2 100644 --- a/builder/azure/common/constants/goos.go +++ b/builder/azure/common/constants/goos.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package constants diff --git a/builder/azure/common/constants/stateBag.go b/builder/azure/common/constants/stateBag.go index 05a0f6711..e7ca98c63 100644 --- a/builder/azure/common/constants/stateBag.go +++ b/builder/azure/common/constants/stateBag.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package constants @@ -29,12 +29,18 @@ const ( Ui string = "ui" ) const ( + ArmBlobEndpoint string = "arm.BlobEndpoint" + ArmCaptureTemplate string = "arm.CaptureTemplate" ArmComputeName string = "arm.ComputeName" + ArmCertificateUrl string = "arm.CertificateUrl" ArmDeploymentName string = "arm.DeploymentName" + ArmKeyVaultName string = "arm.KeyVaultName" ArmLocation string = "arm.Location" ArmOSDiskVhd string = "arm.OSDiskVhd" ArmPublicIPAddressName string = "arm.PublicIPAddressName" ArmResourceGroupName string = "arm.ResourceGroupName" + ArmIsResourceGroupCreated string = "arm.IsResourceGroupCreated" + ArmStorageAccountName string = "arm.StorageAccountName" ArmTemplateParameters string = "arm.TemplateParameters" ArmVirtualMachineCaptureParameters string = "arm.VirtualMachineCaptureParameters" ) diff --git a/builder/azure/common/constants/targetplatforms.go b/builder/azure/common/constants/targetplatforms.go index 953f3d320..d95f7b66c 100644 --- a/builder/azure/common/constants/targetplatforms.go +++ b/builder/azure/common/constants/targetplatforms.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package constants diff --git a/builder/azure/common/devicelogin.go b/builder/azure/common/devicelogin.go new file mode 100644 index 000000000..37725894c --- /dev/null +++ b/builder/azure/common/devicelogin.go @@ -0,0 +1,247 @@ +package common + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "regexp" + + "github.com/Azure/azure-sdk-for-go/arm/resources/subscriptions" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/to" + "github.com/mitchellh/packer/version" + "github.com/mitchellh/go-homedir" +) + +var ( + // AD app id for packer-azure driver. + clientIDs = map[string]string{ + azure.PublicCloud.Name: "04cc58ec-51ab-4833-ac0d-ce3a7912414b", + } + + userAgent = fmt.Sprintf("packer/%s", version.FormattedVersion()) +) + +// NOTE(ahmetalpbalkan): Azure Active Directory implements OAuth 2.0 Device Flow +// described here: https://tools.ietf.org/html/draft-denniss-oauth-device-flow-00 +// Although it has some gotchas, most of the authentication logic is in Azure SDK +// for Go helper packages. +// +// Device auth prints a message to the screen telling the user to click on URL +// and approve the app on the browser, meanwhile the client polls the auth API +// for a token. Once we have token, we save it locally to a file with proper +// permissions and when the token expires (in Azure case typically 1 hour) SDK +// will automatically refresh the specified token and will call the refresh +// callback function we implement here. This way we will always be storing a +// token with a refresh_token saved on the machine. + +// Authenticate fetches a token from the local file cache or initiates a consent +// flow and waits for token to be obtained. +func Authenticate(env azure.Environment, subscriptionID string, say func(string)) (*azure.ServicePrincipalToken, error) { + clientID, ok := clientIDs[env.Name] + if !ok { + return nil, fmt.Errorf("packer-azure application not set up for Azure environment %q", env.Name) + } + + // First we locate the tenant ID of the subscription as we store tokens per + // tenant (which could have multiple subscriptions) + say(fmt.Sprintf("Looking up AAD Tenant ID: subscriptionID=%s.", subscriptionID)) + tenantID, err := findTenantID(env, subscriptionID) + if err != nil { + return nil, err + } + say(fmt.Sprintf("Found AAD Tenant ID: tenantID=%s", tenantID)) + + oauthCfg, err := env.OAuthConfigForTenant(tenantID) + if err != nil { + return nil, fmt.Errorf("Failed to obtain oauth config for azure environment: %v", err) + } + + // for AzurePublicCloud (https://management.core.windows.net/), this old + // Service Management scope covers both ASM and ARM. + apiScope := env.ServiceManagementEndpoint + + tokenPath := tokenCachePath(tenantID) + saveToken := mkTokenCallback(tokenPath) + saveTokenCallback := func(t azure.Token) error { + say("Azure token expired. Saving the refreshed token...") + return saveToken(t) + } + + // Lookup the token cache file for an existing token. + spt, err := tokenFromFile(say, *oauthCfg, tokenPath, clientID, apiScope, saveTokenCallback) + if err != nil { + return nil, err + } + if spt != nil { + say(fmt.Sprintf("Auth token found in file: %s", tokenPath)) + + // NOTE(ahmetalpbalkan): The token file we found may contain an + // expired access_token. In that case, the first call to Azure SDK will + // attempt to refresh the token using refresh_token, which might have + // expired[1], in that case we will get an error and we shall remove the + // token file and initiate token flow again so that the user would not + // need removing the token cache file manually. + // + // [1]: expiration date of refresh_token is not returned in AAD /token + // response, we just know it is 14 days. Therefore user’s token + // will go stale every 14 days and we will delete the token file, + // re-initiate the device flow. + say("Validating the token.") + if err := validateToken(env, spt); err != nil { + say(fmt.Sprintf("Error: %v", err)) + say("Stored Azure credentials expired. Please reauthenticate.") + say(fmt.Sprintf("Deleting %s", tokenPath)) + if err := os.RemoveAll(tokenPath); err != nil { + return nil, fmt.Errorf("Error deleting stale token file: %v", err) + } + } else { + say("Token works.") + return spt, nil + } + } + + // Start an OAuth 2.0 device flow + say(fmt.Sprintf("Initiating device flow: %s", tokenPath)) + spt, err = tokenFromDeviceFlow(say, *oauthCfg, tokenPath, clientID, apiScope) + if err != nil { + return nil, err + } + say("Obtained service principal token.") + if err := saveToken(spt.Token); err != nil { + say("Error occurred saving token to cache file.") + return nil, err + } + return spt, nil +} + +// tokenFromFile returns a token from the specified file if it is found, otherwise +// returns nil. Any error retrieving or creating the token is returned as an error. +func tokenFromFile(say func(string), oauthCfg azure.OAuthConfig, tokenPath, clientID, resource string, + callback azure.TokenRefreshCallback) (*azure.ServicePrincipalToken, error) { + say(fmt.Sprintf("Loading auth token from file: %s", tokenPath)) + if _, err := os.Stat(tokenPath); err != nil { + if os.IsNotExist(err) { // file not found + return nil, nil + } + return nil, err + } + + token, err := azure.LoadToken(tokenPath) + if err != nil { + return nil, fmt.Errorf("Failed to load token from file: %v", err) + } + + spt, err := azure.NewServicePrincipalTokenFromManualToken(oauthCfg, clientID, resource, *token, callback) + if err != nil { + return nil, fmt.Errorf("Error constructing service principal token: %v", err) + } + return spt, nil +} + +// tokenFromDeviceFlow prints a message to the screen for user to take action to +// consent application on a browser and in the meanwhile the authentication +// endpoint is polled until user gives consent, denies or the flow times out. +// Returned token must be saved. +func tokenFromDeviceFlow(say func(string), oauthCfg azure.OAuthConfig, tokenPath, clientID, resource string) (*azure.ServicePrincipalToken, error) { + cl := autorest.NewClientWithUserAgent(userAgent) + deviceCode, err := azure.InitiateDeviceAuth(&cl, oauthCfg, clientID, resource) + if err != nil { + return nil, fmt.Errorf("Failed to start device auth: %v", err) + } + + // Example message: “To sign in, open https://aka.ms/devicelogin and enter + // the code 0000000 to authenticate.” + say(fmt.Sprintf("Microsoft Azure: %s", to.String(deviceCode.Message))) + + token, err := azure.WaitForUserCompletion(&cl, deviceCode) + if err != nil { + return nil, fmt.Errorf("Failed to complete device auth: %v", err) + } + + spt, err := azure.NewServicePrincipalTokenFromManualToken(oauthCfg, clientID, resource, *token) + if err != nil { + return nil, fmt.Errorf("Error constructing service principal token: %v", err) + } + return spt, nil +} + +// tokenCachePath returns the full path the OAuth 2.0 token should be saved at +// for given tenant ID. +func tokenCachePath(tenantID string) string { + dir, err := homedir.Dir() + if err != nil { + dir, _ = filepath.Abs(os.Args[0]) + } + + return filepath.Join(dir, ".azure", "packer", fmt.Sprintf("oauth-%s.json", tenantID)) +} + +// mkTokenCallback returns a callback function that can be used to save the +// token initially or register to the Azure SDK to be called when the token is +// refreshed. +func mkTokenCallback(path string) azure.TokenRefreshCallback { + return func(t azure.Token) error { + if err := azure.SaveToken(path, 0600, t); err != nil { + return err + } + return nil + } +} + +// validateToken makes a call to Azure SDK with given token, essentially making +// sure if the access_token valid, if not it uses SDK’s functionality to +// automatically refresh the token using refresh_token (which might have +// expired). This check is essentially to make sure refresh_token is good. +func validateToken(env azure.Environment, token *azure.ServicePrincipalToken) error { + c := subscriptionsClient(env.ResourceManagerEndpoint) + // WTF(chrboum) + //c.Authorizer = token + _, err := c.List() + if err != nil { + return fmt.Errorf("Token validity check failed: %v", err) + } + return nil +} + +// findTenantID figures out the AAD tenant ID of the subscription by making an +// unauthenticated request to the Get Subscription Details endpoint and parses +// the value from WWW-Authenticate header. +func findTenantID(env azure.Environment, subscriptionID string) (string, error) { + const hdrKey = "WWW-Authenticate" + c := subscriptionsClient(env.ResourceManagerEndpoint) + + // we expect this request to fail (err != nil), but we are only interested + // in headers, so surface the error if the Response is not present (i.e. + // network error etc) + subs, err := c.Get(subscriptionID) + if subs.Response.Response == nil { + return "", fmt.Errorf("Request failed: %v", err) + } + + // Expecting 401 StatusUnauthorized here, just read the header + if subs.StatusCode != http.StatusUnauthorized { + return "", fmt.Errorf("Unexpected response from Get Subscription: %v", err) + } + hdr := subs.Header.Get(hdrKey) + if hdr == "" { + return "", fmt.Errorf("Header %v not found in Get Subscription response", hdrKey) + } + + // Example value for hdr: + // Bearer authorization_uri="https://login.windows.net/996fe9d1-6171-40aa-945b-4c64b63bf655", error="invalid_token", error_description="The authentication failed because of missing 'Authorization' header." + r := regexp.MustCompile(`authorization_uri=".*/([0-9a-f\-]+)"`) + m := r.FindStringSubmatch(hdr) + if m == nil { + return "", fmt.Errorf("Could not find the tenant ID in header: %s %q", hdrKey, hdr) + } + return m[1], nil +} + +func subscriptionsClient(baseURI string) subscriptions.Client { + c := subscriptions.NewClientWithBaseURI(baseURI, "") // used only for unauthenticated requests for generic subs IDs + c.Client.UserAgent += userAgent + return c +} diff --git a/builder/azure/common/gluestrings.go b/builder/azure/common/gluestrings.go index 0988b98a9..2bab92e54 100644 --- a/builder/azure/common/gluestrings.go +++ b/builder/azure/common/gluestrings.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package common diff --git a/builder/azure/common/gluestrings_test.go b/builder/azure/common/gluestrings_test.go index 63379d475..ac4c1528d 100644 --- a/builder/azure/common/gluestrings_test.go +++ b/builder/azure/common/gluestrings_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package common diff --git a/builder/azure/common/interruptible_task.go b/builder/azure/common/interruptible_task.go new file mode 100644 index 000000000..31915dbce --- /dev/null +++ b/builder/azure/common/interruptible_task.go @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package common + +import ( + "time" +) + +type InterruptibleTaskResult struct { + Err error + IsCancelled bool +} + +type InterruptibleTask struct { + IsCancelled func() bool + Task func(cancelCh <-chan struct{}) error +} + +func NewInterruptibleTask(isCancelled func() bool, task func(cancelCh <-chan struct{}) error) *InterruptibleTask { + return &InterruptibleTask{ + IsCancelled: isCancelled, + Task: task, + } +} + +func StartInterruptibleTask(isCancelled func() bool, task func(cancelCh <-chan struct{}) error) InterruptibleTaskResult { + t := NewInterruptibleTask(isCancelled, task) + return t.Run() +} + +func (s *InterruptibleTask) Run() InterruptibleTaskResult { + completeCh := make(chan error) + + cancelCh := make(chan struct{}) + defer close(cancelCh) + + go func() { + err := s.Task(cancelCh) + completeCh <- err + + // senders close, receivers check for close + close(completeCh) + }() + + for { + if s.IsCancelled() { + return InterruptibleTaskResult{Err: nil, IsCancelled: true} + } + + select { + case err := <-completeCh: + return InterruptibleTaskResult{Err: err, IsCancelled: false} + case <-time.After(100 * time.Millisecond): + } + } +} diff --git a/builder/azure/common/interruptible_task_test.go b/builder/azure/common/interruptible_task_test.go new file mode 100644 index 000000000..b11e85e12 --- /dev/null +++ b/builder/azure/common/interruptible_task_test.go @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package common + +import ( + "fmt" + "testing" + "time" +) + +func TestInterruptibleTaskShouldImmediatelyEndOnCancel(t *testing.T) { + testSubject := NewInterruptibleTask( + func() bool { return true }, + func(<-chan struct{}) error { + for { + time.Sleep(time.Second * 30) + } + }) + + result := testSubject.Run() + if result.IsCancelled != true { + t.Fatalf("Expected the task to be cancelled, but it was not.") + } +} + +func TestInterruptibleTaskShouldRunTaskUntilCompletion(t *testing.T) { + var count int + + testSubject := &InterruptibleTask{ + IsCancelled: func() bool { + return false + }, + Task: func(<-chan struct{}) error { + for i := 0; i < 10; i++ { + count += 1 + } + + return nil + }, + } + + result := testSubject.Run() + if result.IsCancelled != false { + t.Errorf("Expected the task to *not* be cancelled, but it was.") + } + + if count != 10 { + t.Errorf("Expected the task to wait for completion, but it did not.") + } + + if result.Err != nil { + t.Errorf("Expected the task to return a nil error, but got=%s", result.Err) + } +} + +func TestInterruptibleTaskShouldImmediatelyStopOnTaskError(t *testing.T) { + testSubject := &InterruptibleTask{ + IsCancelled: func() bool { + return false + }, + Task: func(cancelCh <-chan struct{}) error { + return fmt.Errorf("boom") + }, + } + + result := testSubject.Run() + if result.IsCancelled != false { + t.Errorf("Expected the task to *not* be cancelled, but it was.") + } + + if result.Err == nil { + t.Errorf("Expected the task to return an error, but it did not.") + } +} + +func TestInterruptibleTaskShouldProvideLiveChannel(t *testing.T) { + testSubject := &InterruptibleTask{ + IsCancelled: func() bool { + return false + }, + Task: func(cancelCh <-chan struct{}) error { + isOpen := false + + select { + case _, ok := <-cancelCh: + isOpen = !ok + if !isOpen { + t.Errorf("Expected the channel to open, but it was closed.") + } + default: + isOpen = true + break + } + + if !isOpen { + t.Errorf("Check for openness failed.") + } + + return nil + }, + } + + testSubject.Run() +} diff --git a/builder/azure/common/lin/ssh.go b/builder/azure/common/lin/ssh.go index 6f98254c2..b6d5a6efe 100644 --- a/builder/azure/common/lin/ssh.go +++ b/builder/azure/common/lin/ssh.go @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package lin import ( "fmt" - - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/builder/azure/common/constants" + "github.com/mitchellh/multistep" "golang.org/x/crypto/ssh" ) diff --git a/builder/azure/common/lin/step_create_cert.go b/builder/azure/common/lin/step_create_cert.go index 45e1dc409..5ba38ac43 100644 --- a/builder/azure/common/lin/step_create_cert.go +++ b/builder/azure/common/lin/step_create_cert.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package lin @@ -28,11 +28,11 @@ type StepCreateCert struct { func (s *StepCreateCert) Run(state multistep.StateBag) multistep.StepAction { ui := state.Get("ui").(packer.Ui) - ui.Say("Creating Temporary Certificate...") + ui.Say("Creating temporary certificate...") err := s.createCert(state) if err != nil { - err := fmt.Errorf("Error Creating Temporary Certificate: %s", err) + err := fmt.Errorf("Error creating temporary certificate: %s", err) state.Put("error", err) ui.Error(err.Error()) return multistep.ActionHalt diff --git a/builder/azure/common/lin/step_generalize_os.go b/builder/azure/common/lin/step_generalize_os.go index 24dcad830..508de67b4 100644 --- a/builder/azure/common/lin/step_generalize_os.go +++ b/builder/azure/common/lin/step_generalize_os.go @@ -1,15 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package lin import ( "bytes" "fmt" - "log" - "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" + "log" ) type StepGeneralizeOS struct { diff --git a/builder/azure/common/logutil/logfields.go b/builder/azure/common/logutil/logfields.go new file mode 100644 index 000000000..664ee7196 --- /dev/null +++ b/builder/azure/common/logutil/logfields.go @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package logutil + +import "fmt" + +type Fields map[string]interface{} + +func (f Fields) String() string { + var s string + for k, v := range f { + if sv, ok := v.(string); ok { + v = fmt.Sprintf("%q", sv) + } + s += fmt.Sprintf(" %s=%v", k, v) + } + return s +} diff --git a/builder/azure/common/randomstring.go b/builder/azure/common/randomstring.go index 6d8d752d8..eab3515dd 100644 --- a/builder/azure/common/randomstring.go +++ b/builder/azure/common/randomstring.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package common diff --git a/builder/azure/common/randomstring_test.go b/builder/azure/common/randomstring_test.go index ce375e7b6..6b0c3efc5 100644 --- a/builder/azure/common/randomstring_test.go +++ b/builder/azure/common/randomstring_test.go @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See the LICENSE file in builder/azure for license information. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. package common diff --git a/builder/azure/common/state_bag.go b/builder/azure/common/state_bag.go new file mode 100644 index 000000000..d19a068bc --- /dev/null +++ b/builder/azure/common/state_bag.go @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +package common + +import "github.com/mitchellh/multistep" + +func IsStateCancelled(stateBag multistep.StateBag) bool { + _, ok := stateBag.GetOk(multistep.StateCancelled) + return ok +} diff --git a/builder/azure/common/vault.go b/builder/azure/common/vault.go new file mode 100644 index 000000000..1cf919795 --- /dev/null +++ b/builder/azure/common/vault.go @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See the LICENSE file in the project root for license information. + +// NOTE: vault APIs do not yet exist in the SDK, but once they do this code +// should be removed. + +package common + +import ( + "net/http" + "strings" + + "github.com/Azure/go-autorest/autorest" +) + +const ( + AzureVaultApiVersion = "2015-06-01" + AzureVaultScope = "https://vault.azure.net" + AzureVaultSecretTemplate = "https://{vault-name}.vault.azure.net/secrets/{secret-name}" +) + +type VaultClient struct { + autorest.Client +} + +type Secret struct { + ID *string `json:"id,omitempty"` + Value string `json:"value"` + Attributes SecretAttributes `json:"attributes"` +} + +type SecretAttributes struct { + Enabled bool `json:"enabled"` + Created *string `json:"created"` + Updated *string `json:"updated"` +} + +func (client *VaultClient) GetSecret(vaultName, secretName string) (*Secret, error) { + p := map[string]interface{}{ + "secret-name": secretName, + } + q := map[string]interface{}{ + "api-version": AzureVaultApiVersion, + } + + secretURL := strings.Replace(AzureVaultSecretTemplate, "{vault-name}", vaultName, -1) + + req, err := autorest.Prepare(&http.Request{}, + autorest.AsGet(), + autorest.WithBaseURL(secretURL), + autorest.WithPathParameters(p), + autorest.WithQueryParameters(q)) + + if err != nil { + return nil, err + } + + //resp, err := v.Send(req, http.StatusOK) + resp, err := autorest.SendWithSender(client, req) + if err != nil { + return nil, err + } + + var secret Secret + + err = autorest.Respond( + resp, + autorest.ByUnmarshallingJSON(&secret)) + if err != nil { + return nil, err + } + + return &secret, nil +} diff --git a/builder/azure/pkcs12/README.md b/builder/azure/pkcs12/README.md new file mode 100644 index 000000000..94fd56929 --- /dev/null +++ b/builder/azure/pkcs12/README.md @@ -0,0 +1,9 @@ +This is a fork of the from the original PKCS#12 parsing code +published in the Azure repository [go-pkcs12](https://github.com/Azure/go-pkcs12). +This fork adds serializing a x509 certificate and private key as PKCS#12 binary blob +(aka .PFX file). Due to the specific nature of this code it was not accepted for +inclusion in the official Go crypto repository. + +The methods used for decoding PKCS#12 have been moved to the test files to further +discourage the use of this library for decoding. Please use the official +[pkcs12](https://godoc.org/golang.org/x/crypto/pkcs12) library for decode support. \ No newline at end of file diff --git a/builder/azure/pkcs12/bmp-string.go b/builder/azure/pkcs12/bmp-string.go new file mode 100644 index 000000000..e98069b82 --- /dev/null +++ b/builder/azure/pkcs12/bmp-string.go @@ -0,0 +1,26 @@ +package pkcs12 + +import ( + "errors" + "unicode/utf16" +) + +func bmpString(s string) ([]byte, error) { + // References: + // https://tools.ietf.org/html/rfc7292#appendix-B.1 + // http://en.wikipedia.org/wiki/Plane_(Unicode)#Basic_Multilingual_Plane + // - non-BMP characters are encoded in UTF 16 by using a surrogate pair of 16-bit codes + // EncodeRune returns 0xfffd if the rune does not need special encoding + // - the above RFC provides the info that BMPStrings are NULL terminated. + + rv := make([]byte, 0, 2*len(s)+2) + + for _, r := range s { + if t, _ := utf16.EncodeRune(r); t != 0xfffd { + return nil, errors.New("string contains characters that cannot be encoded in UCS-2") + } + rv = append(rv, byte(r/256), byte(r%256)) + } + rv = append(rv, 0, 0) + return rv, nil +} diff --git a/builder/azure/pkcs12/bmp-string_test.go b/builder/azure/pkcs12/bmp-string_test.go new file mode 100644 index 000000000..85dd96fb4 --- /dev/null +++ b/builder/azure/pkcs12/bmp-string_test.go @@ -0,0 +1,70 @@ +package pkcs12 + +import ( + "bytes" + "errors" + "testing" + "unicode/utf16" +) + +func decodeBMPString(bmpString []byte) (string, error) { + if len(bmpString)%2 != 0 { + return "", errors.New("expected BMP byte string to be an even length") + } + + // strip terminator if present + if terminator := bmpString[len(bmpString)-2:]; terminator[0] == terminator[1] && terminator[1] == 0 { + bmpString = bmpString[:len(bmpString)-2] + } + + s := make([]uint16, 0, len(bmpString)/2) + for len(bmpString) > 0 { + s = append(s, uint16(bmpString[0])*265+uint16(bmpString[1])) + bmpString = bmpString[2:] + } + + return string(utf16.Decode(s)), nil +} + +func TestBMPStringDecode(t *testing.T) { + _, err := decodeBMPString([]byte("a")) + if err == nil { + t.Fatalf("expected decode to fail, but it succeeded") + } +} + +func TestBMPString(t *testing.T) { + str, err := bmpString("") + if bytes.Compare(str, []byte{0, 0}) != 0 { + t.Errorf("expected empty string to return double 0, but found: % x", str) + } + if err != nil { + t.Errorf("err: %v", err) + } + + // Example from https://tools.ietf.org/html/rfc7292#appendix-B + str, err = bmpString("Beavis") + if bytes.Compare(str, []byte{0x00, 0x42, 0x00, 0x65, 0x00, 0x61, 0x00, 0x0076, 0x00, 0x69, 0x00, 0x73, 0x00, 0x00}) != 0 { + t.Errorf("expected 'Beavis' to return 0x00 0x42 0x00 0x65 0x00 0x61 0x00 0x76 0x00 0x69 0x00 0x73 0x00 0x00, but found: % x", str) + } + if err != nil { + t.Errorf("err: %v", err) + } + + // some characters from the "Letterlike Symbols Unicode block" + tst := "\u2115 - Double-struck N" + str, err = bmpString(tst) + if bytes.Compare(str, []byte{0x21, 0x15, 0x00, 0x20, 0x00, 0x2d, 0x00, 0x20, 0x00, 0x44, 0x00, 0x6f, 0x00, 0x75, 0x00, 0x62, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x2d, 0x00, 0x73, 0x00, 0x74, 0x00, 0x72, 0x00, 0x75, 0x00, 0x63, 0x00, 0x6b, 0x00, 0x20, 0x00, 0x4e, 0x00, 0x00}) != 0 { + t.Errorf("expected '%s' to return 0x21 0x15 0x00 0x20 0x00 0x2d 0x00 0x20 0x00 0x44 0x00 0x6f 0x00 0x75 0x00 0x62 0x00 0x6c 0x00 0x65 0x00 0x2d 0x00 0x73 0x00 0x74 0x00 0x72 0x00 0x75 0x00 0x63 0x00 0x6b 0x00 0x20 0x00 0x4e 0x00 0x00, but found: % x", tst, str) + } + if err != nil { + t.Errorf("err: %v", err) + } + + // some character outside the BMP should error + tst = "\U0001f000 East wind (Mahjong)" + str, err = bmpString(tst) + if err == nil { + t.Errorf("expected '%s' to throw error because the first character is not in the BMP", tst) + } +} diff --git a/builder/azure/pkcs12/crypto.go b/builder/azure/pkcs12/crypto.go new file mode 100644 index 000000000..17bca7d79 --- /dev/null +++ b/builder/azure/pkcs12/crypto.go @@ -0,0 +1,82 @@ +// implementation of https://tools.ietf.org/html/rfc2898#section-6.1.2 + +package pkcs12 + +import ( + "bytes" + "crypto/cipher" + "crypto/des" + "crypto/rand" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "io" + + "github.com/mitchellh/packer/builder/azure/pkcs12/rc2" +) + +const ( + pbeWithSHAAnd3KeyTripleDESCBC = "pbeWithSHAAnd3-KeyTripleDES-CBC" + pbewithSHAAnd40BitRC2CBC = "pbewithSHAAnd40BitRC2-CBC" +) + +const ( + pbeIterationCount = 2048 + pbeSaltSizeBytes = 8 +) + +var ( + oidPbeWithSHAAnd3KeyTripleDESCBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 1, 3} + oidPbewithSHAAnd40BitRC2CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 1, 6} +) + +var algByOID = map[string]string{ + oidPbeWithSHAAnd3KeyTripleDESCBC.String(): pbeWithSHAAnd3KeyTripleDESCBC, + oidPbewithSHAAnd40BitRC2CBC.String(): pbewithSHAAnd40BitRC2CBC, +} + +var blockcodeByAlg = map[string]func(key []byte) (cipher.Block, error){ + pbeWithSHAAnd3KeyTripleDESCBC: des.NewTripleDESCipher, + pbewithSHAAnd40BitRC2CBC: func(key []byte) (cipher.Block, error) { + return rc2.New(key, len(key)*8) + }, +} + +type pbeParams struct { + Salt []byte + Iterations int +} + +func pad(src []byte, blockSize int) []byte { + paddingLength := blockSize - len(src)%blockSize + paddingText := bytes.Repeat([]byte{byte(paddingLength)}, paddingLength) + return append(src, paddingText...) +} + +func pbEncrypt(plainText, salt, password []byte, iterations int) (cipherText []byte, err error) { + _, err = io.ReadFull(rand.Reader, salt) + if err != nil { + return nil, errors.New("pkcs12: failed to create a random salt value: " + err.Error()) + } + + key := deriveKeyByAlg[pbeWithSHAAnd3KeyTripleDESCBC](salt, password, iterations) + iv := deriveIVByAlg[pbeWithSHAAnd3KeyTripleDESCBC](salt, password, iterations) + + block, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, errors.New("pkcs12: failed to create a block cipher: " + err.Error()) + } + + paddedPlainText := pad(plainText, block.BlockSize()) + + encrypter := cipher.NewCBCEncrypter(block, iv) + cipherText = make([]byte, len(paddedPlainText)) + encrypter.CryptBlocks(cipherText, paddedPlainText) + + return cipherText, nil +} + +type decryptable interface { + GetAlgorithm() pkix.AlgorithmIdentifier + GetData() []byte +} diff --git a/builder/azure/pkcs12/crypto_test.go b/builder/azure/pkcs12/crypto_test.go new file mode 100644 index 000000000..3f40ba771 --- /dev/null +++ b/builder/azure/pkcs12/crypto_test.go @@ -0,0 +1,159 @@ +package pkcs12 + +import ( + "bytes" + "crypto/cipher" + "crypto/x509/pkix" + "encoding/asn1" + "testing" +) + +func pbDecrypterFor(algorithm pkix.AlgorithmIdentifier, password []byte) (cipher.BlockMode, error) { + algorithmName, supported := algByOID[algorithm.Algorithm.String()] + if !supported { + return nil, NotImplementedError("algorithm " + algorithm.Algorithm.String() + " is not supported") + } + + var params pbeParams + if _, err := asn1.Unmarshal(algorithm.Parameters.FullBytes, ¶ms); err != nil { + return nil, err + } + + k := deriveKeyByAlg[algorithmName](params.Salt, password, params.Iterations) + iv := deriveIVByAlg[algorithmName](params.Salt, password, params.Iterations) + password = nil + + code, err := blockcodeByAlg[algorithmName](k) + if err != nil { + return nil, err + } + + cbc := cipher.NewCBCDecrypter(code, iv) + return cbc, nil +} + +func pbDecrypt(info decryptable, password []byte) (decrypted []byte, err error) { + cbc, err := pbDecrypterFor(info.GetAlgorithm(), password) + password = nil + if err != nil { + return nil, err + } + + encrypted := info.GetData() + + decrypted = make([]byte, len(encrypted)) + cbc.CryptBlocks(decrypted, encrypted) + + if psLen := int(decrypted[len(decrypted)-1]); psLen > 0 && psLen <= cbc.BlockSize() { + m := decrypted[:len(decrypted)-psLen] + ps := decrypted[len(decrypted)-psLen:] + if bytes.Compare(ps, bytes.Repeat([]byte{byte(psLen)}, psLen)) != 0 { + return nil, ErrDecryption + } + decrypted = m + } else { + return nil, ErrDecryption + } + + return +} + +func TestPbDecrypterFor(t *testing.T) { + params, _ := asn1.Marshal(pbeParams{ + Salt: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + Iterations: 2048, + }) + alg := pkix.AlgorithmIdentifier{ + Algorithm: asn1.ObjectIdentifier([]int{1, 2, 3}), + Parameters: asn1.RawValue{ + FullBytes: params, + }, + } + + pass, _ := bmpString("Sesame open") + + _, err := pbDecrypterFor(alg, pass) + if _, ok := err.(NotImplementedError); !ok { + t.Errorf("expected not implemented error, got: %T %s", err, err) + } + + alg.Algorithm = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 1, 3}) + cbc, err := pbDecrypterFor(alg, pass) + if err != nil { + t.Errorf("err: %v", err) + } + + M := []byte{1, 2, 3, 4, 5, 6, 7, 8} + expectedM := []byte{185, 73, 135, 249, 137, 1, 122, 247} + cbc.CryptBlocks(M, M) + + if bytes.Compare(M, expectedM) != 0 { + t.Errorf("expected M to be '%d', but found '%d", expectedM, M) + } +} + +func TestPbDecrypt(t *testing.T) { + + tests := [][]byte{ + []byte("\x33\x73\xf3\x9f\xda\x49\xae\xfc\xa0\x9a\xdf\x5a\x58\xa0\xea\x46"), // 7 padding bytes + []byte("\x33\x73\xf3\x9f\xda\x49\xae\xfc\x96\x24\x2f\x71\x7e\x32\x3f\xe7"), // 8 padding bytes + []byte("\x35\x0c\xc0\x8d\xab\xa9\x5d\x30\x7f\x9a\xec\x6a\xd8\x9b\x9c\xd9"), // 9 padding bytes, incorrect + []byte("\xb2\xf9\x6e\x06\x60\xae\x20\xcf\x08\xa0\x7b\xd9\x6b\x20\xef\x41"), // incorrect padding bytes: [ ... 0x04 0x02 ] + } + expected := []interface{}{ + []byte("A secret!"), + []byte("A secret"), + ErrDecryption, + ErrDecryption, + } + + for i, c := range tests { + td := testDecryptable{ + data: c, + algorithm: pkix.AlgorithmIdentifier{ + Algorithm: asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 1, 3}), // SHA1/3TDES + Parameters: pbeParams{ + Salt: []byte("\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8"), + Iterations: 4096, + }.RawASN1(), + }, + } + p, _ := bmpString("sesame") + + m, err := pbDecrypt(td, p) + + switch e := expected[i].(type) { + case []byte: + if err != nil { + t.Errorf("error decrypting C=%x: %v", c, err) + } + if bytes.Compare(m, e) != 0 { + t.Errorf("expected C=%x to be decoded to M=%x, but found %x", c, e, m) + } + case error: + if err == nil || err.Error() != e.Error() { + t.Errorf("expecting error '%v' during decryption of c=%x, but found err='%v'", e, c, err) + } + } + } +} + +type testDecryptable struct { + data []byte + algorithm pkix.AlgorithmIdentifier +} + +func (d testDecryptable) GetAlgorithm() pkix.AlgorithmIdentifier { return d.algorithm } +func (d testDecryptable) GetData() []byte { return d.data } + +func (params pbeParams) RawASN1() (raw asn1.RawValue) { + asn1Bytes, err := asn1.Marshal(params) + if err != nil { + panic(err) + } + _, err = asn1.Unmarshal(asn1Bytes, &raw) + if err != nil { + panic(err) + } + return +} diff --git a/builder/azure/pkcs12/errors.go b/builder/azure/pkcs12/errors.go new file mode 100644 index 000000000..64a9433ef --- /dev/null +++ b/builder/azure/pkcs12/errors.go @@ -0,0 +1,24 @@ +package pkcs12 + +import "errors" + +var ( + // ErrDecryption represents a failure to decrypt the input. + ErrDecryption = errors.New("pkcs12: decryption error, incorrect padding") + + // ErrIncorrectPassword is returned when an incorrect password is detected. + // Usually, P12/PFX data is signed to be able to verify the password. + ErrIncorrectPassword = errors.New("pkcs12: decryption password incorrect") +) + +// NotImplementedError indicates that the input is not currently supported. +type NotImplementedError string +type EncodeError string + +func (e NotImplementedError) Error() string { + return string(e) +} + +func (e EncodeError) Error() string { + return "pkcs12: encode error: " + string(e) +} diff --git a/builder/azure/pkcs12/mac.go b/builder/azure/pkcs12/mac.go new file mode 100644 index 000000000..c7e428117 --- /dev/null +++ b/builder/azure/pkcs12/mac.go @@ -0,0 +1,33 @@ +package pkcs12 + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/x509/pkix" + "encoding/asn1" +) + +var ( + oidSha1Algorithm = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 26} +) + +type macData struct { + Mac digestInfo + MacSalt []byte + Iterations int `asn1:"optional,default:1"` +} + +// from PKCS#7: +type digestInfo struct { + Algorithm pkix.AlgorithmIdentifier + Digest []byte +} + +func computeMac(message []byte, iterations int, salt, password []byte) []byte { + key := pbkdf(sha1Sum, 20, 64, salt, password, iterations, 3, 20) + + mac := hmac.New(sha1.New, key) + mac.Write(message) + + return mac.Sum(nil) +} diff --git a/builder/azure/pkcs12/mac_test.go b/builder/azure/pkcs12/mac_test.go new file mode 100644 index 000000000..0ccea7104 --- /dev/null +++ b/builder/azure/pkcs12/mac_test.go @@ -0,0 +1,52 @@ +package pkcs12 + +import ( + "crypto/hmac" + "encoding/asn1" + "testing" +) + +func verifyMac(macData *macData, message, password []byte) error { + if !macData.Mac.Algorithm.Algorithm.Equal(oidSha1Algorithm) { + return NotImplementedError("unknown digest algorithm: " + macData.Mac.Algorithm.Algorithm.String()) + } + + expectedMAC := computeMac(message, macData.Iterations, macData.MacSalt, password) + + if !hmac.Equal(macData.Mac.Digest, expectedMAC) { + return ErrIncorrectPassword + } + return nil +} + +func TestVerifyMac(t *testing.T) { + td := macData{ + Mac: digestInfo{ + Digest: []byte{0x18, 0x20, 0x3d, 0xff, 0x1e, 0x16, 0xf4, 0x92, 0xf2, 0xaf, 0xc8, 0x91, 0xa9, 0xba, 0xd6, 0xca, 0x9d, 0xee, 0x51, 0x93}, + }, + MacSalt: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + Iterations: 2048, + } + + message := []byte{11, 12, 13, 14, 15} + password, _ := bmpString("") + + td.Mac.Algorithm.Algorithm = asn1.ObjectIdentifier([]int{1, 2, 3}) + err := verifyMac(&td, message, password) + if _, ok := err.(NotImplementedError); !ok { + t.Errorf("err: %v", err) + } + + td.Mac.Algorithm.Algorithm = asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26}) + err = verifyMac(&td, message, password) + if err != ErrIncorrectPassword { + t.Errorf("Expected incorrect password, got err: %v", err) + } + + password, _ = bmpString("Sesame open") + err = verifyMac(&td, message, password) + if err != nil { + t.Errorf("err: %v", err) + } + +} diff --git a/builder/azure/pkcs12/pbkdf.go b/builder/azure/pkcs12/pbkdf.go new file mode 100644 index 000000000..29685723f --- /dev/null +++ b/builder/azure/pkcs12/pbkdf.go @@ -0,0 +1,187 @@ +package pkcs12 + +import ( + "crypto/sha1" + "math/big" +) + +var ( + deriveKeyByAlg = map[string]func(salt, password []byte, iterations int) []byte{ + pbeWithSHAAnd3KeyTripleDESCBC: func(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 1, 24) + }, + pbewithSHAAnd40BitRC2CBC: func(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 1, 5) + }, + } + deriveIVByAlg = map[string]func(salt, password []byte, iterations int) []byte{ + pbeWithSHAAnd3KeyTripleDESCBC: func(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 2, 8) + }, + pbewithSHAAnd40BitRC2CBC: func(salt, password []byte, iterations int) []byte { + return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 2, 8) + }, + } +) + +func sha1Sum(in []byte) []byte { + sum := sha1.Sum(in) + return sum[:] +} + +func pbkdf(hash func([]byte) []byte, u, v int, salt, password []byte, r int, ID byte, size int) (key []byte) { + // implementation of https://tools.ietf.org/html/rfc7292#appendix-B.2 , RFC text verbatim in comments + + // Let H be a hash function built around a compression function f: + + // Z_2^u x Z_2^v -> Z_2^u + + // (that is, H has a chaining variable and output of length u bits, and + // the message input to the compression function of H is v bits). The + // values for u and v are as follows: + + // HASH FUNCTION VALUE u VALUE v + // MD2, MD5 128 512 + // SHA-1 160 512 + // SHA-224 224 512 + // SHA-256 256 512 + // SHA-384 384 1024 + // SHA-512 512 1024 + // SHA-512/224 224 1024 + // SHA-512/256 256 1024 + + // Furthermore, let r be the iteration count. + + // We assume here that u and v are both multiples of 8, as are the + // lengths of the password and salt strings (which we denote by p and s, + // respectively) and the number n of pseudorandom bits required. In + // addition, u and v are of course non-zero. + + // For information on security considerations for MD5 [19], see [25] and + // [1], and on those for MD2, see [18]. + + // The following procedure can be used to produce pseudorandom bits for + // a particular "purpose" that is identified by a byte called "ID". + // This standard specifies 3 different values for the ID byte: + + // 1. If ID=1, then the pseudorandom bits being produced are to be used + // as key material for performing encryption or decryption. + + // 2. If ID=2, then the pseudorandom bits being produced are to be used + // as an IV (Initial Value) for encryption or decryption. + + // 3. If ID=3, then the pseudorandom bits being produced are to be used + // as an integrity key for MACing. + + // 1. Construct a string, D (the "diversifier"), by concatenating v/8 + // copies of ID. + D := []byte{} + for i := 0; i < v; i++ { + D = append(D, ID) + } + + // 2. Concatenate copies of the salt together to create a string S of + // length v(ceiling(s/v)) bits (the final copy of the salt may be + // truncated to create S). Note that if the salt is the empty + // string, then so is S. + + S := []byte{} + { + s := len(salt) + times := s / v + if s%v > 0 { + times++ + } + for len(S) < times*v { + S = append(S, salt...) + } + S = S[:times*v] + } + + // 3. Concatenate copies of the password together to create a string P + // of length v(ceiling(p/v)) bits (the final copy of the password + // may be truncated to create P). Note that if the password is the + // empty string, then so is P. + + P := []byte{} + { + s := len(password) + times := s / v + if s%v > 0 { + times++ + } + for len(P) < times*v { + P = append(P, password...) + } + password = nil + P = P[:times*v] + } + + // 4. Set I=S||P to be the concatenation of S and P. + I := append(S, P...) + + // 5. Set c=ceiling(n/u). + c := size / u + if size%u > 0 { + c++ + } + + // 6. For i=1, 2, ..., c, do the following: + A := make([]byte, c*20) + for i := 0; i < c; i++ { + + // A. Set A2=H^r(D||I). (i.e., the r-th hash of D||1, + // H(H(H(... H(D||I)))) + Ai := hash(append(D, I...)) + for j := 1; j < r; j++ { + Ai = hash(Ai[:]) + } + copy(A[i*20:], Ai[:]) + + if i < c-1 { // skip on last iteration + + // B. Concatenate copies of Ai to create a string B of length v + // bits (the final copy of Ai may be truncated to create B). + B := []byte{} + for len(B) < v { + B = append(B, Ai[:]...) + } + B = B[:v] + + // C. Treating I as a concatenation I_0, I_1, ..., I_(k-1) of v-bit + // blocks, where k=ceiling(s/v)+ceiling(p/v), modify I by + // setting I_j=(I_j+B+1) mod 2^v for each j. + { + Bbi := new(big.Int) + Bbi.SetBytes(B) + + one := big.NewInt(1) + + for j := 0; j < len(I)/v; j++ { + Ij := new(big.Int) + Ij.SetBytes(I[j*v : (j+1)*v]) + Ij.Add(Ij, Bbi) + Ij.Add(Ij, one) + Ijb := Ij.Bytes() + if len(Ijb) > v { + Ijb = Ijb[len(Ijb)-v:] + } + copy(I[j*v:(j+1)*v], Ijb) + } + } + } + } + // 7. Concatenate A_1, A_2, ..., A_c together to form a pseudorandom + // bit string, A. + + // 8. Use the first n bits of A as the output of this entire process. + A = A[:size] + + return A + + // If the above process is being used to generate a DES key, the process + // should be used to create 64 random bits, and the key's parity bits + // should be set after the 64 bits have been produced. Similar concerns + // hold for 2-key and 3-key triple-DES keys, for CDMF keys, and for any + // similar keys with parity bits "built into them". +} diff --git a/builder/azure/pkcs12/pbkdf_test.go b/builder/azure/pkcs12/pbkdf_test.go new file mode 100644 index 000000000..a52b697d8 --- /dev/null +++ b/builder/azure/pkcs12/pbkdf_test.go @@ -0,0 +1,18 @@ +package pkcs12 + +import ( + "bytes" + "testing" +) + +func TestThatPBKDFWorksCorrectlyForLongKeys(t *testing.T) { + pbkdf := deriveKeyByAlg[pbeWithSHAAnd3KeyTripleDESCBC] + + salt := []byte("\xff\xff\xff\xff\xff\xff\xff\xff") + password, _ := bmpString("sesame") + key := pbkdf(salt, password, 2048) + + if expected := []byte("\x7c\xd9\xfd\x3e\x2b\x3b\xe7\x69\x1a\x44\xe3\xbe\xf0\xf9\xea\x0f\xb9\xb8\x97\xd4\xe3\x25\xd9\xd1"); bytes.Compare(key, expected) != 0 { + t.Fatalf("expected key '% x', but found '% x'", key, expected) + } +} diff --git a/builder/azure/pkcs12/pkcs12.go b/builder/azure/pkcs12/pkcs12.go new file mode 100644 index 000000000..5772b2d4d --- /dev/null +++ b/builder/azure/pkcs12/pkcs12.go @@ -0,0 +1,263 @@ +// Package pkcs12 provides some implementations of PKCS#12. +// +// This implementation is distilled from https://tools.ietf.org/html/rfc7292 and referenced documents. +// It is intended for decoding P12/PFX-stored certificate+key for use with the crypto/tls package. +package pkcs12 + +import ( + "crypto/rand" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "io" +) + +var ( + oidLocalKeyID = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 21} + oidDataContentType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 1} + + localKeyId = []byte{0x01, 0x00, 0x00, 0x00} +) + +type pfxPdu struct { + Version int + AuthSafe contentInfo + MacData macData `asn1:"optional"` +} + +type contentInfo struct { + ContentType asn1.ObjectIdentifier + Content asn1.RawValue `asn1:"tag:0,explicit,optional"` +} + +type encryptedData struct { + Version int + EncryptedContentInfo encryptedContentInfo +} + +type encryptedContentInfo struct { + ContentType asn1.ObjectIdentifier + ContentEncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedContent []byte `asn1:"tag:0,optional"` +} + +func (i encryptedContentInfo) GetAlgorithm() pkix.AlgorithmIdentifier { + return i.ContentEncryptionAlgorithm +} + +func (i encryptedContentInfo) GetData() []byte { return i.EncryptedContent } + +type safeBag struct { + Id asn1.ObjectIdentifier + Value asn1.RawValue `asn1:"tag:0,explicit"` + Attributes []pkcs12Attribute `asn1:"set,optional"` +} + +type pkcs12Attribute struct { + Id asn1.ObjectIdentifier + Value asn1.RawValue `ans1:"set"` +} + +type encryptedPrivateKeyInfo struct { + AlgorithmIdentifier pkix.AlgorithmIdentifier + EncryptedData []byte +} + +func (i encryptedPrivateKeyInfo) GetAlgorithm() pkix.AlgorithmIdentifier { return i.AlgorithmIdentifier } +func (i encryptedPrivateKeyInfo) GetData() []byte { return i.EncryptedData } + +// unmarshal calls asn1.Unmarshal, but also returns an error if there is any +// trailing data after unmarshaling. +func unmarshal(in []byte, out interface{}) error { + trailing, err := asn1.Unmarshal(in, out) + if err != nil { + return err + } + if len(trailing) != 0 { + return errors.New("pkcs12: trailing data found") + } + return nil +} + +func getLocalKeyId(id []byte) (attribute pkcs12Attribute, err error) { + octetString := asn1.RawValue{Tag: 4, Class: 0, IsCompound: false, Bytes: id} + bytes, err := asn1.Marshal(octetString) + if err != nil { + return + } + + attribute = pkcs12Attribute{ + Id: oidLocalKeyID, + Value: asn1.RawValue{Tag: 17, Class: 0, IsCompound: true, Bytes: bytes}, + } + + return attribute, nil +} + +func convertToRawVal(val interface{}) (raw asn1.RawValue, err error) { + bytes, err := asn1.Marshal(val) + if err != nil { + return + } + + _, err = asn1.Unmarshal(bytes, &raw) + return raw, nil +} + +func makeSafeBags(oid asn1.ObjectIdentifier, value []byte) ([]safeBag, error) { + attribute, err := getLocalKeyId(localKeyId) + + if err != nil { + return nil, EncodeError("local key id: " + err.Error()) + } + + bag := make([]safeBag, 1) + bag[0] = safeBag{ + Id: oid, + Value: asn1.RawValue{Tag: 0, Class: 2, IsCompound: true, Bytes: value}, + Attributes: []pkcs12Attribute{attribute}, + } + + return bag, nil +} + +func makeCertBagContentInfo(derBytes []byte) (*contentInfo, error) { + certBag1 := certBag{ + Id: oidCertTypeX509Certificate, + Data: derBytes, + } + + bytes, err := asn1.Marshal(certBag1) + if err != nil { + return nil, EncodeError("encoding cert bag: " + err.Error()) + } + + certSafeBags, err := makeSafeBags(oidCertBagType, bytes) + if err != nil { + return nil, EncodeError("safe bags: " + err.Error()) + } + + return makeContentInfo(certSafeBags) +} + +func makeShroudedKeyBagContentInfo(privateKey interface{}, password []byte) (*contentInfo, error) { + shroudedKeyBagBytes, err := encodePkcs8ShroudedKeyBag(privateKey, password) + if err != nil { + return nil, EncodeError("encode PKCS#8 shrouded key bag: " + err.Error()) + } + + safeBags, err := makeSafeBags(oidPkcs8ShroudedKeyBagType, shroudedKeyBagBytes) + if err != nil { + return nil, EncodeError("safe bags: " + err.Error()) + } + + return makeContentInfo(safeBags) +} + +func makeContentInfo(val interface{}) (*contentInfo, error) { + fullBytes, err := asn1.Marshal(val) + if err != nil { + return nil, EncodeError("contentInfo raw value marshal: " + err.Error()) + } + + octetStringVal := asn1.RawValue{Tag: 4, Class: 0, IsCompound: false, Bytes: fullBytes} + octetStringFullBytes, err := asn1.Marshal(octetStringVal) + if err != nil { + return nil, EncodeError("raw contentInfo to octet string: " + err.Error()) + } + + contentInfo := contentInfo{ContentType: oidDataContentType} + contentInfo.Content = asn1.RawValue{Tag: 0, Class: 2, IsCompound: true, Bytes: octetStringFullBytes} + + return &contentInfo, nil +} + +func makeContentInfos(derBytes []byte, privateKey interface{}, password []byte) ([]contentInfo, error) { + shroudedKeyContentInfo, err := makeShroudedKeyBagContentInfo(privateKey, password) + if err != nil { + return nil, EncodeError("shrouded key content info: " + err.Error()) + } + + certBagContentInfo, err := makeCertBagContentInfo(derBytes) + if err != nil { + return nil, EncodeError("cert bag content info: " + err.Error()) + } + + contentInfos := make([]contentInfo, 2) + contentInfos[0] = *shroudedKeyContentInfo + contentInfos[1] = *certBagContentInfo + + return contentInfos, nil +} + +func makeSalt(saltByteCount int) ([]byte, error) { + salt := make([]byte, saltByteCount) + _, err := io.ReadFull(rand.Reader, salt) + return salt, err +} + +// Encode converts a certificate and a private key to the PKCS#12 byte stream format. +// +// derBytes is a DER encoded certificate. +// privateKey is an RSA +func Encode(derBytes []byte, privateKey interface{}, password string) (pfxBytes []byte, err error) { + secret, err := bmpString(password) + if err != nil { + return nil, ErrIncorrectPassword + } + + contentInfos, err := makeContentInfos(derBytes, privateKey, secret) + if err != nil { + return nil, err + } + + // Marhsal []contentInfo so we can re-constitute the byte stream that will + // be suitable for computing the MAC + bytes, err := asn1.Marshal(contentInfos) + if err != nil { + return nil, err + } + + // Unmarshal as an asn1.RawValue so, we can compute the MAC against the .Bytes + var contentInfosRaw asn1.RawValue + err = unmarshal(bytes, &contentInfosRaw) + if err != nil { + return nil, err + } + + authSafeContentInfo, err := makeContentInfo(contentInfosRaw) + if err != nil { + return nil, EncodeError("authSafe content info: " + err.Error()) + } + + salt, err := makeSalt(pbeSaltSizeBytes) + if err != nil { + return nil, EncodeError("salt value: " + err.Error()) + } + + // Compute the MAC for marshaled bytes of contentInfos, which includes the + // cert bag, and the shrouded key bag. + digest := computeMac(contentInfosRaw.FullBytes, pbeIterationCount, salt, secret) + + pfx := pfxPdu{ + Version: 3, + AuthSafe: *authSafeContentInfo, + MacData: macData{ + Iterations: pbeIterationCount, + MacSalt: salt, + Mac: digestInfo{ + Algorithm: pkix.AlgorithmIdentifier{ + Algorithm: oidSha1Algorithm, + }, + Digest: digest, + }, + }, + } + + bytes, err = asn1.Marshal(pfx) + if err != nil { + return nil, EncodeError("marshal PFX PDU: " + err.Error()) + } + + return bytes, err +} diff --git a/builder/azure/pkcs12/pkcs12_test.go b/builder/azure/pkcs12/pkcs12_test.go new file mode 100644 index 000000000..b48e618a0 --- /dev/null +++ b/builder/azure/pkcs12/pkcs12_test.go @@ -0,0 +1,117 @@ +package pkcs12 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "testing" + "time" + + gopkcs12 "golang.org/x/crypto/pkcs12" +) + +func TestPfxRoundTriRsa(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatal(err.Error()) + } + + key := testPfxRoundTrip(t, privateKey) + + actualPrivateKey, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("failed to decode private key") + } + + if privateKey.D.Cmp(actualPrivateKey.D) != 0 { + t.Errorf("priv.D") + } +} + +func TestPfxRoundTriEcdsa(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err.Error()) + } + + key := testPfxRoundTrip(t, privateKey) + + actualPrivateKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("failed to decode private key") + } + + if privateKey.D.Cmp(actualPrivateKey.D) != 0 { + t.Errorf("priv.D") + } +} + +func testPfxRoundTrip(t *testing.T, privateKey interface{}) interface{} { + certificateBytes, err := newCertificate("hostname", privateKey) + if err != nil { + t.Fatal(err.Error()) + } + + bytes, err := Encode(certificateBytes, privateKey, "sesame") + if err != nil { + t.Fatal(err.Error()) + } + + key, _, err := gopkcs12.Decode(bytes, "sesame") + if err != nil { + t.Fatalf(err.Error()) + } + + return key +} + +func newCertificate(hostname string, privateKey interface{}) ([]byte, error) { + t, _ := time.Parse("2006-01-02", "2016-01-01") + notBefore := t + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + err := fmt.Errorf("Failed to Generate Serial Number: %v", err) + return nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Issuer: pkix.Name{ + CommonName: hostname, + }, + Subject: pkix.Name{ + CommonName: hostname, + }, + + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + var publicKey interface{} + switch key := privateKey.(type) { + case *rsa.PrivateKey: + publicKey = key.Public() + case *ecdsa.PrivateKey: + publicKey = key.Public() + default: + panic(fmt.Sprintf("unsupported private key type: %T", privateKey)) + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("Failed to Generate derBytes: " + err.Error()) + } + + return derBytes, nil +} diff --git a/builder/azure/pkcs12/pkcs8.go b/builder/azure/pkcs12/pkcs8.go new file mode 100644 index 000000000..1eb8850f9 --- /dev/null +++ b/builder/azure/pkcs12/pkcs8.go @@ -0,0 +1,64 @@ +package pkcs12 + +import ( + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" +) + +// pkcs8 reflects an ASN.1, PKCS#8 PrivateKey. See +// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn +// and RFC5208. +type pkcs8 struct { + Version int + Algo pkix.AlgorithmIdentifier + PrivateKey []byte + // optional attributes omitted. +} + +var ( + oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + + nullAsn = asn1.RawValue{Tag: 5} +) + +// marshalPKCS8PrivateKey converts a private key to PKCS#8 encoded form. +// See http://www.rsa.com/rsalabs/node.asp?id=2130 and RFC5208. +func marshalPKCS8PrivateKey(key interface{}) (der []byte, err error) { + pkcs := pkcs8{ + Version: 0, + } + + switch key := key.(type) { + case *rsa.PrivateKey: + pkcs.Algo = pkix.AlgorithmIdentifier{ + Algorithm: oidPublicKeyRSA, + Parameters: nullAsn, + } + pkcs.PrivateKey = x509.MarshalPKCS1PrivateKey(key) + case *ecdsa.PrivateKey: + bytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return nil, errors.New("x509: failed to marshal to PKCS#8: " + err.Error()) + } + + pkcs.Algo = pkix.AlgorithmIdentifier{ + Algorithm: oidPublicKeyECDSA, + Parameters: nullAsn, + } + pkcs.PrivateKey = bytes + default: + return nil, errors.New("x509: PKCS#8 only RSA and ECDSA private keys supported") + } + + bytes, err := asn1.Marshal(pkcs) + if err != nil { + return nil, errors.New("x509: failed to marshal to PKCS#8: " + err.Error()) + } + + return bytes, nil +} diff --git a/builder/azure/pkcs12/pkcs8_test.go b/builder/azure/pkcs12/pkcs8_test.go new file mode 100644 index 000000000..7d12119f5 --- /dev/null +++ b/builder/azure/pkcs12/pkcs8_test.go @@ -0,0 +1,141 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package pkcs12 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/asn1" + "testing" +) + +func TestRoundTripPkcs8Rsa(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatalf("failed to generate a private key: %s", err) + } + + bytes, err := marshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatalf("failed to marshal private key: %s", err) + } + + key, err := x509.ParsePKCS8PrivateKey(bytes) + if err != nil { + t.Fatalf("failed to parse private key: %s", err) + } + + actualPrivateKey, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("expected key to be of type *rsa.PrivateKey, but actual was %T", key) + } + + if actualPrivateKey.Validate() != nil { + t.Fatalf("private key did not validate") + } + + if actualPrivateKey.N.Cmp(privateKey.N) != 0 { + t.Errorf("private key's N did not round trip") + } + if actualPrivateKey.D.Cmp(privateKey.D) != 0 { + t.Errorf("private key's D did not round trip") + } + if actualPrivateKey.E != privateKey.E { + t.Errorf("private key's E did not round trip") + } + if actualPrivateKey.Primes[0].Cmp(privateKey.Primes[0]) != 0 { + t.Errorf("private key's P did not round trip") + } + if actualPrivateKey.Primes[1].Cmp(privateKey.Primes[1]) != 0 { + t.Errorf("private key's Q did not round trip") + } +} + +func TestRoundTripPkcs8Ecdsa(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate a private key: %s", err) + } + + bytes, err := marshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatalf("failed to marshal private key: %s", err) + } + + key, err := x509.ParsePKCS8PrivateKey(bytes) + if err != nil { + t.Fatalf("failed to parse private key: %s", err) + } + + actualPrivateKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("expected key to be of type *ecdsa.PrivateKey, but actual was %T", key) + } + + // sanity check, not exhaustive + if actualPrivateKey.D.Cmp(privateKey.D) != 0 { + t.Errorf("private key's D did not round trip") + } + if actualPrivateKey.X.Cmp(privateKey.X) != 0 { + t.Errorf("private key's X did not round trip") + } + if actualPrivateKey.Y.Cmp(privateKey.Y) != 0 { + t.Errorf("private key's Y did not round trip") + } + if actualPrivateKey.Curve.Params().B.Cmp(privateKey.Curve.Params().B) != 0 { + t.Errorf("private key's Curve.B did not round trip") + } +} + +func TestNullParametersPkcs8Rsa(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatalf("failed to generate a private key: %s", err) + } + + checkNullParameter(t, privateKey) +} + +func TestNullParametersPkcs8Ecdsa(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate a private key: %s", err) + } + + checkNullParameter(t, privateKey) +} + +func checkNullParameter(t *testing.T, privateKey interface{}) { + bytes, err := marshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatalf("failed to marshal private key: %s", err) + } + + var pkcs pkcs8 + rest, err := asn1.Unmarshal(bytes, &pkcs) + if err != nil { + t.Fatalf("failed to unmarshal PKCS#8: %s", err) + } + + if len(rest) != 0 { + t.Fatalf("unexpected trailing bytes of len=%d, bytes=%x", len(rest), rest) + } + + // Only version == 0 is known and valid + if pkcs.Version != 0 { + t.Errorf("expected version=0, but actual=%d", pkcs.Version) + } + + // ensure a NULL parameter is inserted + if pkcs.Algo.Parameters.Tag != 5 { + t.Errorf("expected parameters to be NULL, but actual tag=%d, class=%d, isCompound=%t, bytes=%x", + pkcs.Algo.Parameters.Tag, + pkcs.Algo.Parameters.Class, + pkcs.Algo.Parameters.IsCompound, + pkcs.Algo.Parameters.Bytes) + } +} diff --git a/builder/azure/pkcs12/rc2/rc2.go b/builder/azure/pkcs12/rc2/rc2.go new file mode 100644 index 000000000..aa194e501 --- /dev/null +++ b/builder/azure/pkcs12/rc2/rc2.go @@ -0,0 +1,284 @@ +// Package rc2 implements the RC2 cipher +/* +https://www.ietf.org/rfc/rfc2268.txt +http://people.csail.mit.edu/rivest/pubs/KRRR98.pdf + +This code is licensed under the MIT license. +*/ +package rc2 + +import ( + "crypto/cipher" + "encoding/binary" + "strconv" +) + +// The rc2 block size in bytes +const BlockSize = 8 + +type rc2Cipher struct { + k [64]uint16 +} + +// KeySizeError indicates the supplied key was invalid +type KeySizeError int + +func (k KeySizeError) Error() string { return "rc2: invalid key size " + strconv.Itoa(int(k)) } + +// EffectiveKeySizeError indicates the supplied effective key length was invalid +type EffectiveKeySizeError int + +func (k EffectiveKeySizeError) Error() string { + return "rc2: invalid effective key size " + strconv.Itoa(int(k)) +} + +// New returns a new rc2 cipher with the given key and effective key length t1 +func New(key []byte, t1 int) (cipher.Block, error) { + if l := len(key); l == 0 || l > 128 { + return nil, KeySizeError(l) + } + + if t1 < 8 || t1 > 1024 { + return nil, EffectiveKeySizeError(t1) + } + + return &rc2Cipher{ + k: expandKey(key, t1), + }, nil +} + +func (c *rc2Cipher) BlockSize() int { return BlockSize } + +var piTable = [256]byte{ + 0xd9, 0x78, 0xf9, 0xc4, 0x19, 0xdd, 0xb5, 0xed, 0x28, 0xe9, 0xfd, 0x79, 0x4a, 0xa0, 0xd8, 0x9d, + 0xc6, 0x7e, 0x37, 0x83, 0x2b, 0x76, 0x53, 0x8e, 0x62, 0x4c, 0x64, 0x88, 0x44, 0x8b, 0xfb, 0xa2, + 0x17, 0x9a, 0x59, 0xf5, 0x87, 0xb3, 0x4f, 0x13, 0x61, 0x45, 0x6d, 0x8d, 0x09, 0x81, 0x7d, 0x32, + 0xbd, 0x8f, 0x40, 0xeb, 0x86, 0xb7, 0x7b, 0x0b, 0xf0, 0x95, 0x21, 0x22, 0x5c, 0x6b, 0x4e, 0x82, + 0x54, 0xd6, 0x65, 0x93, 0xce, 0x60, 0xb2, 0x1c, 0x73, 0x56, 0xc0, 0x14, 0xa7, 0x8c, 0xf1, 0xdc, + 0x12, 0x75, 0xca, 0x1f, 0x3b, 0xbe, 0xe4, 0xd1, 0x42, 0x3d, 0xd4, 0x30, 0xa3, 0x3c, 0xb6, 0x26, + 0x6f, 0xbf, 0x0e, 0xda, 0x46, 0x69, 0x07, 0x57, 0x27, 0xf2, 0x1d, 0x9b, 0xbc, 0x94, 0x43, 0x03, + 0xf8, 0x11, 0xc7, 0xf6, 0x90, 0xef, 0x3e, 0xe7, 0x06, 0xc3, 0xd5, 0x2f, 0xc8, 0x66, 0x1e, 0xd7, + 0x08, 0xe8, 0xea, 0xde, 0x80, 0x52, 0xee, 0xf7, 0x84, 0xaa, 0x72, 0xac, 0x35, 0x4d, 0x6a, 0x2a, + 0x96, 0x1a, 0xd2, 0x71, 0x5a, 0x15, 0x49, 0x74, 0x4b, 0x9f, 0xd0, 0x5e, 0x04, 0x18, 0xa4, 0xec, + 0xc2, 0xe0, 0x41, 0x6e, 0x0f, 0x51, 0xcb, 0xcc, 0x24, 0x91, 0xaf, 0x50, 0xa1, 0xf4, 0x70, 0x39, + 0x99, 0x7c, 0x3a, 0x85, 0x23, 0xb8, 0xb4, 0x7a, 0xfc, 0x02, 0x36, 0x5b, 0x25, 0x55, 0x97, 0x31, + 0x2d, 0x5d, 0xfa, 0x98, 0xe3, 0x8a, 0x92, 0xae, 0x05, 0xdf, 0x29, 0x10, 0x67, 0x6c, 0xba, 0xc9, + 0xd3, 0x00, 0xe6, 0xcf, 0xe1, 0x9e, 0xa8, 0x2c, 0x63, 0x16, 0x01, 0x3f, 0x58, 0xe2, 0x89, 0xa9, + 0x0d, 0x38, 0x34, 0x1b, 0xab, 0x33, 0xff, 0xb0, 0xbb, 0x48, 0x0c, 0x5f, 0xb9, 0xb1, 0xcd, 0x2e, + 0xc5, 0xf3, 0xdb, 0x47, 0xe5, 0xa5, 0x9c, 0x77, 0x0a, 0xa6, 0x20, 0x68, 0xfe, 0x7f, 0xc1, 0xad, +} + +func expandKey(key []byte, t1 int) [64]uint16 { + + l := make([]byte, 128) + copy(l, key) + + var t = len(key) + var t8 = (t1 + 7) / 8 + var tm = byte(255 % uint(1<<(8+uint(t1)-8*uint(t8)))) + + for i := len(key); i < 128; i++ { + l[i] = piTable[l[i-1]+l[uint8(i-t)]] + } + + l[128-t8] = piTable[l[128-t8]&tm] + + for i := 127 - t8; i >= 0; i-- { + l[i] = piTable[l[i+1]^l[i+t8]] + } + + var k [64]uint16 + + for i := range k { + k[i] = uint16(l[2*i]) + uint16(l[2*i+1])*256 + } + + return k +} + +func rotl16(x uint16, b uint) uint16 { + return (x >> (16 - b)) | (x << b) +} + +func (c *rc2Cipher) Encrypt(dst, src []byte) { + + r0 := binary.LittleEndian.Uint16(src[0:]) + r1 := binary.LittleEndian.Uint16(src[2:]) + r2 := binary.LittleEndian.Uint16(src[4:]) + r3 := binary.LittleEndian.Uint16(src[6:]) + + var j int + + // These three mix blocks have not been extracted to a common function for to performance reasons. + for j <= 16 { + // mix r0 + r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) + r0 = rotl16(r0, 1) + j++ + + // mix r1 + r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2) + r1 = rotl16(r1, 2) + j++ + + // mix r2 + r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3) + r2 = rotl16(r2, 3) + j++ + + // mix r3 + r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0) + r3 = rotl16(r3, 5) + j++ + } + + r0 = r0 + c.k[r3&63] + r1 = r1 + c.k[r0&63] + r2 = r2 + c.k[r1&63] + r3 = r3 + c.k[r2&63] + + for j <= 40 { + // mix r0 + r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) + r0 = rotl16(r0, 1) + j++ + + // mix r1 + r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2) + r1 = rotl16(r1, 2) + j++ + + // mix r2 + r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3) + r2 = rotl16(r2, 3) + j++ + + // mix r3 + r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0) + r3 = rotl16(r3, 5) + j++ + } + + r0 = r0 + c.k[r3&63] + r1 = r1 + c.k[r0&63] + r2 = r2 + c.k[r1&63] + r3 = r3 + c.k[r2&63] + + for j <= 60 { + // mix r0 + r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) + r0 = rotl16(r0, 1) + j++ + + // mix r1 + r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2) + r1 = rotl16(r1, 2) + j++ + + // mix r2 + r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3) + r2 = rotl16(r2, 3) + j++ + + // mix r3 + r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0) + r3 = rotl16(r3, 5) + j++ + } + + binary.LittleEndian.PutUint16(dst[0:], r0) + binary.LittleEndian.PutUint16(dst[2:], r1) + binary.LittleEndian.PutUint16(dst[4:], r2) + binary.LittleEndian.PutUint16(dst[6:], r3) +} + +func (c *rc2Cipher) Decrypt(dst, src []byte) { + + r0 := binary.LittleEndian.Uint16(src[0:]) + r1 := binary.LittleEndian.Uint16(src[2:]) + r2 := binary.LittleEndian.Uint16(src[4:]) + r3 := binary.LittleEndian.Uint16(src[6:]) + + j := 63 + + for j >= 44 { + // unmix r3 + r3 = rotl16(r3, 16-5) + r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) + j-- + + // unmix r2 + r2 = rotl16(r2, 16-3) + r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3) + j-- + + // unmix r1 + r1 = rotl16(r1, 16-2) + r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2) + j-- + + // unmix r0 + r0 = rotl16(r0, 16-1) + r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1) + j-- + } + + r3 = r3 - c.k[r2&63] + r2 = r2 - c.k[r1&63] + r1 = r1 - c.k[r0&63] + r0 = r0 - c.k[r3&63] + + for j >= 20 { + // unmix r3 + r3 = rotl16(r3, 16-5) + r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) + j-- + + // unmix r2 + r2 = rotl16(r2, 16-3) + r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3) + j-- + + // unmix r1 + r1 = rotl16(r1, 16-2) + r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2) + j-- + + // unmix r0 + r0 = rotl16(r0, 16-1) + r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1) + j-- + } + + r3 = r3 - c.k[r2&63] + r2 = r2 - c.k[r1&63] + r1 = r1 - c.k[r0&63] + r0 = r0 - c.k[r3&63] + + for j >= 0 { + // unmix r3 + r3 = rotl16(r3, 16-5) + r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) + j-- + + // unmix r2 + r2 = rotl16(r2, 16-3) + r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3) + j-- + + // unmix r1 + r1 = rotl16(r1, 16-2) + r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2) + j-- + + // unmix r0 + r0 = rotl16(r0, 16-1) + r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1) + j-- + } + + binary.LittleEndian.PutUint16(dst[0:], r0) + binary.LittleEndian.PutUint16(dst[2:], r1) + binary.LittleEndian.PutUint16(dst[4:], r2) + binary.LittleEndian.PutUint16(dst[6:], r3) +} diff --git a/builder/azure/pkcs12/rc2/rc2_test.go b/builder/azure/pkcs12/rc2/rc2_test.go new file mode 100644 index 000000000..adb73c592 --- /dev/null +++ b/builder/azure/pkcs12/rc2/rc2_test.go @@ -0,0 +1,105 @@ +package rc2 + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func TestEncryptDecrypt(t *testing.T) { + + var tests = []struct { + key string + plain string + cipher string + t1 int + }{ + { + "0000000000000000", + "0000000000000000", + "ebb773f993278eff", + 63, + }, + { + "ffffffffffffffff", + "ffffffffffffffff", + "278b27e42e2f0d49", + 64, + }, + { + "3000000000000000", + "1000000000000001", + "30649edf9be7d2c2", + 64, + }, + { + "88", + "0000000000000000", + "61a8a244adacccf0", + 64, + }, + { + "88bca90e90875a", + "0000000000000000", + "6ccf4308974c267f", + 64, + }, + { + "88bca90e90875a7f0f79c384627bafb2", + "0000000000000000", + "1a807d272bbe5db1", + 64, + }, + { + "88bca90e90875a7f0f79c384627bafb2", + "0000000000000000", + "2269552ab0f85ca6", + 128, + }, + { + "88bca90e90875a7f0f79c384627bafb216f80a6f85920584c42fceb0be255daf1e", + "0000000000000000", + "5b78d3a43dfff1f1", + 129, + }, + } + + for _, tt := range tests { + k, _ := hex.DecodeString(tt.key) + p, _ := hex.DecodeString(tt.plain) + c, _ := hex.DecodeString(tt.cipher) + + b, _ := New(k, tt.t1) + + var dst [8]byte + + b.Encrypt(dst[:], p) + + if !bytes.Equal(dst[:], c) { + t.Errorf("encrypt failed: got % 2x wanted % 2x\n", dst, c) + } + + b.Decrypt(dst[:], c) + + if !bytes.Equal(dst[:], p) { + t.Errorf("decrypt failed: got % 2x wanted % 2x\n", dst, p) + } + } +} + +func BenchmarkEncrypt(b *testing.B) { + r, _ := New([]byte{0, 0, 0, 0, 0, 0, 0, 0}, 64) + b.ResetTimer() + var src [8]byte + for i := 0; i < b.N; i++ { + r.Encrypt(src[:], src[:]) + } +} +func BenchmarkDecrypt(b *testing.B) { + r, _ := New([]byte{0, 0, 0, 0, 0, 0, 0, 0}, 64) + b.ResetTimer() + var src [8]byte + for i := 0; i < b.N; i++ { + r.Decrypt(src[:], src[:]) + } +} diff --git a/builder/azure/pkcs12/safebags.go b/builder/azure/pkcs12/safebags.go new file mode 100644 index 000000000..5d4793b3e --- /dev/null +++ b/builder/azure/pkcs12/safebags.go @@ -0,0 +1,67 @@ +package pkcs12 + +import ( + "crypto/x509/pkix" + "encoding/asn1" + "errors" +) + +//see https://tools.ietf.org/html/rfc7292#appendix-D +var ( + oidPkcs8ShroudedKeyBagType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 10, 1, 2} + oidCertBagType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 10, 1, 3} + + oidCertTypeX509Certificate = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 22, 1} +) + +type certBag struct { + Id asn1.ObjectIdentifier + Data []byte `asn1:"tag:0,explicit"` +} + +func getAlgorithmParams(salt []byte, iterations int) (asn1.RawValue, error) { + params := pbeParams{ + Salt: salt, + Iterations: iterations, + } + + return convertToRawVal(params) +} + +func encodePkcs8ShroudedKeyBag(privateKey interface{}, password []byte) (bytes []byte, err error) { + privateKeyBytes, err := marshalPKCS8PrivateKey(privateKey) + + if err != nil { + return nil, errors.New("pkcs12: error encoding PKCS#8 private key: " + err.Error()) + } + + salt, err := makeSalt(pbeSaltSizeBytes) + if err != nil { + return nil, errors.New("pkcs12: error creating PKCS#8 salt: " + err.Error()) + } + + pkData, err := pbEncrypt(privateKeyBytes, salt, password, pbeIterationCount) + if err != nil { + return nil, errors.New("pkcs12: error encoding PKCS#8 shrouded key bag when encrypting cert bag: " + err.Error()) + } + + params, err := getAlgorithmParams(salt, pbeIterationCount) + if err != nil { + return nil, errors.New("pkcs12: error encoding PKCS#8 shrouded key bag algorithm's parameters: " + err.Error()) + } + + pkinfo := encryptedPrivateKeyInfo{ + AlgorithmIdentifier: pkix.AlgorithmIdentifier{ + Algorithm: oidPbeWithSHAAnd3KeyTripleDESCBC, + Parameters: params, + }, + EncryptedData: pkData, + } + + bytes, err = asn1.Marshal(pkinfo) + if err != nil { + return nil, errors.New("pkcs12: error encoding PKCS#8 shrouded key bag: " + err.Error()) + } + + return bytes, err +} diff --git a/builder/azure/pkcs12/safebags_test.go b/builder/azure/pkcs12/safebags_test.go new file mode 100644 index 000000000..bb30d4e55 --- /dev/null +++ b/builder/azure/pkcs12/safebags_test.go @@ -0,0 +1,102 @@ +package pkcs12 + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/asn1" + "fmt" + "testing" +) + +func decodePkcs8ShroudedKeyBag(asn1Data, password []byte) (privateKey interface{}, err error) { + pkinfo := new(encryptedPrivateKeyInfo) + if _, err = asn1.Unmarshal(asn1Data, pkinfo); err != nil { + err = fmt.Errorf("error decoding PKCS8 shrouded key bag: %v", err) + return nil, err + } + + pkData, err := pbDecrypt(pkinfo, password) + if err != nil { + err = fmt.Errorf("error decrypting PKCS8 shrouded key bag: %v", err) + return + } + + rv := new(asn1.RawValue) + if _, err = asn1.Unmarshal(pkData, rv); err != nil { + err = fmt.Errorf("could not decode decrypted private key data") + } + + if privateKey, err = x509.ParsePKCS8PrivateKey(pkData); err != nil { + err = fmt.Errorf("error parsing PKCS8 private key: %v", err) + return nil, err + } + return +} + +// Assert the default algorithm parameters are in the correct order, +// and default to the correct value. Defaults are based on OpenSSL. +// 1. IterationCount, defaults to 2,048 long. +// 2. Salt, is 8 bytes long. +func TestDefaultAlgorithmParametersPkcs8ShroudedKeyBag(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatalf("failed to generate a private key: %s", err) + } + + password := []byte("sesame") + bytes, err := encodePkcs8ShroudedKeyBag(privateKey, password) + if err != nil { + t.Fatalf("failed to encode PKCS#8 shrouded key bag: %s", err) + } + + var pkinfo encryptedPrivateKeyInfo + rest, err := asn1.Unmarshal(bytes, &pkinfo) + if err != nil { + t.Fatalf("failed to unmarshal encryptedPrivateKeyInfo %s", err) + } + + if len(rest) != 0 { + t.Fatalf("unexpected trailing bytes of len=%d, bytes=%x", len(rest), rest) + } + + var params pbeParams + rest, err = asn1.Unmarshal(pkinfo.GetAlgorithm().Parameters.FullBytes, ¶ms) + if err != nil { + t.Fatalf("failed to unmarshal encryptedPrivateKeyInfo %s", err) + } + + if len(rest) != 0 { + t.Fatalf("unexpected trailing bytes of len=%d, bytes=%x", len(rest), rest) + } + + if params.Iterations != pbeIterationCount { + t.Errorf("expected iteration count to be %d, but actual=%d", pbeIterationCount, params.Iterations) + } + if len(params.Salt) != pbeSaltSizeBytes { + t.Errorf("expected the number of salt bytes to be %d, but actual=%d", pbeSaltSizeBytes, len(params.Salt)) + } +} + +func TestRoundTripPkcs8ShroudedKeyBag(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatalf("failed to generate a private key: %s", err) + } + + password := []byte("sesame") + bytes, err := encodePkcs8ShroudedKeyBag(privateKey, password) + if err != nil { + t.Fatalf("failed to encode PKCS#8 shrouded key bag: %s", err) + } + + key, err := decodePkcs8ShroudedKeyBag(bytes, password) + if err != nil { + t.Fatalf("failed to decode PKCS#8 shrouded key bag: %s", err) + } + + actualPrivateKey := key.(*rsa.PrivateKey) + if actualPrivateKey.D.Cmp(privateKey.D) != 0 { + t.Fatalf("failed to round-trip rsa.PrivateKey.D") + } +}