Add support for Windows to Azure.

This is last merge that will happen from the github.com/Azure/packer-Azure
repository.  All development is being over to this repository.

The biggest change in this merge is support for Windows.  There are a few other
fixes as well.

 * If the user cancels the build, clean up any resources.
 * Output a reasonable build artifact.
 * Log requests and responses with Azure.
 * Support for US Government and the China clouds.
 * Support interrupting long running tasks.
 * Allow the user to set the image version.
 * Device login support.
This commit is contained in:
Christopher Boumenot 2016-04-21 16:50:03 -07:00 committed by Chris Bednarski
parent 0abf0bbab0
commit c7018a00c8
82 changed files with 4773 additions and 393 deletions

View File

@ -6,21 +6,91 @@ Please see the
[overview](https://azure.microsoft.com/en-us/documentation/articles/resource-group-overview/)
for more information about ARM as well as the benefit of ARM.
## Getting Started
## Device Login vs. Service Principal Name (SPN)
The ARM APIs use OAUTH to authenticate, so you must create a Service
Principal. The following articles are a good starting points.
There are two ways to get started with packer-azure. The simplest is device login, and only requires a Subscription ID.
Device login is only supported for Linux based VMs. The second is the use of an SPN. We recommend the device login
approach for those who are first time users, and just want to ''kick the tires.'' We recommend the SPN approach if you
intend to automate Packer, or you are deploying Windows VMs.
## Device Login
A sample template for device login is show below. There are three pieces of information
you must provide to enable device login mode.
1. SubscriptionID
1. Resource Group - parent resource group that Packer uses to build an image.
1. Storage Account - storage account where the image will be placed.
> Device login mode is enabled by not setting client_id, client_secret, and tenant_id.
The device login flow asks that you open a web browser, navigate to http://aka.ms/devicelogin, and input the supplied
code. This authorizes the Packer for Azure application to act on your behalf. An OAuth token will be created, and
stored in the user's home directory (~/.azure/packer/oauth-TenantID.json, and TenantID will be replaced with the actual
Tenant ID). This token is used if it exists, and refreshed as necessary.
```json
{
"variables": {
"sid": "your_subscription_id",
"rgn": "your_resource_group",
"sa": "your_storage_account"
},
"builders": [
{
"type": "azure-arm",
"subscription_id": "{{user `sid`}}",
"resource_group_name": "{{user `rgn`}}",
"storage_account": "{{user `sa`}}",
"capture_container_name": "images",
"capture_name_prefix": "packer",
"os_type": "Linux",
"image_publisher": "Canonical",
"image_offer": "UbuntuServer",
"image_sku": "14.04.3-LTS",
"location": "South Central US",
"vm_size": "Standard_A2"
}
],
"provisioners": [
{
"execute_command": "chmod +x {{ .Path }}; {{ .Vars }} sudo -E sh '{{ .Path }}'",
"inline": [
"apt-get update",
"apt-get upgrade -y",
"/usr/sbin/waagent -force -deprovision+user && export HISTSIZE=0 && sync"
],
"inline_shebang": "/bin/sh -x",
"type": "shell"
}
]
}
```
## Service Principal Name
The ARM APIs use OAUTH to authenticate, and requires an SPN. The following articles
are a good starting points for creating a new SPN.
* [Automating Azure on your CI server using a Service Principal](http://blog.davidebbo.com/2014/12/azure-service-principal.html)
* [Authenticating a service principal with Azure Resource Manager](https://azure.microsoft.com/en-us/documentation/articles/resource-group-authenticate-service-principal/)
There are three pieces of configuration you will need as a result of
creating a Service Principal.
There are three (four in the case of Windows) pieces of configuration you need to note
after creating an SPN.
1. Client ID (aka Service Principal ID)
1. Client Secret (aka Service Principal generated key)
1. Client Tenant (aka Azure Active Directory tenant that owns the
Service Principal)
1. Object ID (Windows only) - a certificate is used to authenticate WinRM access, and the certificate is injected into
the VM using Azure Key Vault. Access to the key vault is protected by an ACL associated with the SPN's ObjectID.
Linux does not need nor use a key vault, so there's no need to know the ObjectID.
You will also need the following.
@ -38,7 +108,8 @@ for more details.
* Owner
> NOTE: the Owner role is too powerful, and more explicit set of roles
> is TBD. Issue #183 is tracking this work.
> is TBD. Issue #183 is tracking this work. Permissions can be scoped to
> a specific resource group to further limit access.
### Sample Ubuntu
@ -65,26 +136,30 @@ Azure for ARM builder.
"subscription_id": "{{user `sid`}}",
"tenant_id": "{{user `tid`}}",
"capture_container_name": "images",
"capture_name_prefix": "my_prefix",
"resource_group_name": "{{user `rgn`}}",
"storage_account": "{{user `sa`}}",
"capture_container_name": "images",
"capture_name_prefix": "packer",
"os_type": "Linux",
"image_publisher": "Canonical",
"image_offer": "UbuntuServer",
"image_sku": "14.04.3-LTS",
"location": "South Central US",
"resource_group_name": "{{user `rgn`}}",
"storage_account": "{{user `sa`}}",
"vm_size": "Standard_A1"
"vm_size": "Standard_A2"
}
],
"provisioners": [
{
"execute_command": "chmod +x {{ .Path }}; {{ .Vars }} sudo -E sh '{{ .Path }}'",
"inline": [
"sudo apt-get update",
"apt-get update",
"apt-get upgrade -y",
"/usr/sbin/waagent -force -deprovision+user && export HISTSIZE=0 && sync"
],
"inline_shebang": "/bin/sh -x",
"type": "shell"

View File

@ -1,36 +1,112 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"bytes"
"fmt"
"net/url"
"path"
"strings"
)
const (
BuilderId = "Azure.ResourceManagement.VMImage"
)
type artifact struct {
name string
type Artifact struct {
StorageAccountLocation string
OSDiskUri string
TemplateUri string
OSDiskUriReadOnlySas string
TemplateUriReadOnlySas string
}
func (*artifact) BuilderId() string {
func NewArtifact(template *CaptureTemplate, getSasUrl func(name string) string) (*Artifact, error) {
if template == nil {
return nil, fmt.Errorf("nil capture template")
}
if len(template.Resources) != 1 {
return nil, fmt.Errorf("malformed capture template, expected one resource")
}
vhdUri, err := url.Parse(template.Resources[0].Properties.StorageProfile.OSDisk.Image.Uri)
if err != nil {
return nil, err
}
templateUri, err := storageUriToTemplateUri(vhdUri)
if err != nil {
return nil, err
}
return &Artifact{
OSDiskUri: vhdUri.String(),
OSDiskUriReadOnlySas: getSasUrl(getStorageUrlPath(vhdUri)),
TemplateUri: templateUri.String(),
TemplateUriReadOnlySas: getSasUrl(getStorageUrlPath(templateUri)),
StorageAccountLocation: template.Resources[0].Location,
}, nil
}
func getStorageUrlPath(u *url.URL) string {
parts := strings.Split(u.Path, "/")
return strings.Join(parts[3:], "/")
}
func storageUriToTemplateUri(su *url.URL) (*url.URL, error) {
// packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd -> 4085bb15-3644-4641-b9cd-f575918640b4
filename := path.Base(su.Path)
parts := strings.Split(filename, ".")
if len(parts) < 3 {
return nil, fmt.Errorf("malformed URL")
}
// packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd -> packer
prefixParts := strings.Split(parts[0], "-")
prefix := strings.Join(prefixParts[:len(prefixParts)-1], "-")
templateFilename := fmt.Sprintf("%s-vmTemplate.%s.json", prefix, parts[1])
// https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd"
// ->
// https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json"
return url.Parse(strings.Replace(su.String(), filename, templateFilename, 1))
}
func (*Artifact) BuilderId() string {
return BuilderId
}
func (*artifact) Files() []string {
func (*Artifact) Files() []string {
return []string{}
}
func (*artifact) Id() string {
func (*Artifact) Id() string {
return ""
}
func (*artifact) State(name string) interface{} {
func (*Artifact) State(name string) interface{} {
return nil
}
func (*artifact) String() string {
return "{}"
func (a *Artifact) String() string {
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("%s:\n\n", a.BuilderId()))
buf.WriteString(fmt.Sprintf("StorageAccountLocation: %s\n", a.StorageAccountLocation))
buf.WriteString(fmt.Sprintf("OSDiskUri: %s\n", a.OSDiskUri))
buf.WriteString(fmt.Sprintf("OSDiskUriReadOnlySas: %s\n", a.OSDiskUriReadOnlySas))
buf.WriteString(fmt.Sprintf("TemplateUri: %s\n", a.TemplateUri))
buf.WriteString(fmt.Sprintf("TemplateUriReadOnlySas: %s\n", a.TemplateUriReadOnlySas))
return buf.String()
}
func (*artifact) Destroy() error {
func (*Artifact) Destroy() error {
return nil
}

View File

@ -0,0 +1,155 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"strings"
"testing"
)
func getFakeSasUrl(name string) string {
return fmt.Sprintf("SAS-%s", name)
}
func TestArtifactString(t *testing.T) {
template := CaptureTemplate{
Resources: []CaptureResources{
CaptureResources{
Properties: CaptureProperties{
StorageProfile: CaptureStorageProfile{
OSDisk: CaptureDisk{
Image: CaptureUri{
Uri: "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd",
},
},
},
},
Location: "southcentralus",
},
},
}
artifact, err := NewArtifact(&template, getFakeSasUrl)
if err != nil {
t.Fatalf("err=%s", err)
}
testSubject := artifact.String()
if !strings.Contains(testSubject, "OSDiskUri: https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd") {
t.Errorf("Expected String() output to contain OSDiskUri")
}
if !strings.Contains(testSubject, "OSDiskUriReadOnlySas: SAS-Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd") {
t.Errorf("Expected String() output to contain OSDiskUriReadOnlySas")
}
if !strings.Contains(testSubject, "TemplateUri: https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json") {
t.Errorf("Expected String() output to contain TemplateUri")
}
if !strings.Contains(testSubject, "TemplateUriReadOnlySas: SAS-Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json") {
t.Errorf("Expected String() output to contain TemplateUriReadOnlySas")
}
if !strings.Contains(testSubject, "StorageAccountLocation: southcentralus") {
t.Errorf("Expected String() output to contain StorageAccountLocation")
}
}
func TestArtifactProperties(t *testing.T) {
template := CaptureTemplate{
Resources: []CaptureResources{
CaptureResources{
Properties: CaptureProperties{
StorageProfile: CaptureStorageProfile{
OSDisk: CaptureDisk{
Image: CaptureUri{
Uri: "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd",
},
},
},
},
Location: "southcentralus",
},
},
}
testSubject, err := NewArtifact(&template, getFakeSasUrl)
if err != nil {
t.Fatalf("err=%s", err)
}
if testSubject.OSDiskUri != "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd" {
t.Errorf("Expected template to be 'https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd', but got %s", testSubject.OSDiskUri)
}
if testSubject.OSDiskUriReadOnlySas != "SAS-Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd" {
t.Errorf("Expected template to be 'SAS-Images/images/packer-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd', but got %s", testSubject.OSDiskUriReadOnlySas)
}
if testSubject.TemplateUri != "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" {
t.Errorf("Expected template to be 'https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json', but got %s", testSubject.TemplateUri)
}
if testSubject.TemplateUriReadOnlySas != "SAS-Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" {
t.Errorf("Expected template to be 'SAS-Images/images/packer-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json', but got %s", testSubject.TemplateUriReadOnlySas)
}
if testSubject.StorageAccountLocation != "southcentralus" {
t.Errorf("Expected StorageAccountLocation to be 'southcentral', but got %s", testSubject.StorageAccountLocation)
}
}
func TestArtifactOverHypenatedCaptureUri(t *testing.T) {
template := CaptureTemplate{
Resources: []CaptureResources{
CaptureResources{
Properties: CaptureProperties{
StorageProfile: CaptureStorageProfile{
OSDisk: CaptureDisk{
Image: CaptureUri{
Uri: "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/pac-ker-osDisk.4085bb15-3644-4641-b9cd-f575918640b4.vhd",
},
},
},
},
Location: "southcentralus",
},
},
}
testSubject, err := NewArtifact(&template, getFakeSasUrl)
if err != nil {
t.Fatalf("err=%s", err)
}
if testSubject.TemplateUri != "https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/pac-ker-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json" {
t.Errorf("Expected template to be 'https://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/pac-ker-vmTemplate.4085bb15-3644-4641-b9cd-f575918640b4.json', but got %s", testSubject.TemplateUri)
}
}
func TestArtifactRejectMalformedTemplates(t *testing.T) {
template := CaptureTemplate{}
_, err := NewArtifact(&template, getFakeSasUrl)
if err == nil {
t.Fatalf("Expected artifact creation to fail, but it succeeded.")
}
}
func TestArtifactRejectMalformedStorageUri(t *testing.T) {
template := CaptureTemplate{
Resources: []CaptureResources{
CaptureResources{
Properties: CaptureProperties{
StorageProfile: CaptureStorageProfile{
OSDisk: CaptureDisk{
Image: CaptureUri{
Uri: "bark",
},
},
},
},
},
},
}
_, err := NewArtifact(&template, getFakeSasUrl)
if err == nil {
t.Fatalf("Expected artifact creation to fail, but it succeeded.")
}
}

View File

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import "github.com/Azure/go-autorest/autorest/azure"
type Authenticate struct {
env azure.Environment
clientID string
clientSecret string
tenantID string
}
func NewAuthenticate(env azure.Environment, clientID, clientSecret, tenantID string) *Authenticate {
return &Authenticate{
env: env,
clientID: clientID,
clientSecret: clientSecret,
tenantID: tenantID,
}
}
func (a *Authenticate) getServicePrincipalToken() (*azure.ServicePrincipalToken, error) {
return a.getServicePrincipalTokenWithResource(a.env.ResourceManagerEndpoint)
}
func (a *Authenticate) getServicePrincipalTokenWithResource(resource string) (*azure.ServicePrincipalToken, error) {
oauthConfig, err := newOAuthConfigWithTenant(a.tenantID)
if err != nil {
return nil, err
}
spt, err := azure.NewServicePrincipalToken(
*oauthConfig,
a.clientID,
a.clientSecret,
resource)
return spt, err
}
func newOAuthConfigWithTenant(tenantID string) (*azure.OAuthConfig, error) {
return azure.PublicCloud.OAuthConfigForTenant(tenantID)
}

View File

@ -0,0 +1,39 @@
package arm
import (
"github.com/Azure/go-autorest/autorest/azure"
"testing"
)
// Behavior is the most important thing to assert for ServicePrincipalToken, but
// that cannot be done in a unit test because it involves network access. Instead,
// I assert the expected intertness of this class.
func TestNewAuthenticate(t *testing.T) {
testSubject := NewAuthenticate(azure.PublicCloud, "clientID", "clientString", "tenantID")
spn, err := testSubject.getServicePrincipalToken()
if err != nil {
t.Fatalf(err.Error())
}
if spn.Token.AccessToken != "" {
t.Errorf("spn.Token.AccessToken: expected=\"\", actual=%s", spn.Token.AccessToken)
}
if spn.Token.RefreshToken != "" {
t.Errorf("spn.Token.RefreshToken: expected=\"\", actual=%s", spn.Token.RefreshToken)
}
if spn.Token.ExpiresIn != "" {
t.Errorf("spn.Token.ExpiresIn: expected=\"\", actual=%s", spn.Token.ExpiresIn)
}
if spn.Token.ExpiresOn != "" {
t.Errorf("spn.Token.ExpiresOn: expected=\"\", actual=%s", spn.Token.ExpiresOn)
}
if spn.Token.NotBefore != "" {
t.Errorf("spn.Token.NotBefore: expected=\"\", actual=%s", spn.Token.NotBefore)
}
if spn.Token.Resource != "" {
t.Errorf("spn.Token.Resource: expected=\"\", actual=%s", spn.Token.Resource)
}
if spn.Token.Type != "" {
t.Errorf("spn.Token.Type: expected=\"\", actual=%s", spn.Token.Type)
}
}

View File

@ -1,45 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"encoding/json"
"fmt"
"math"
"net/http"
"os"
"strconv"
"github.com/Azure/azure-sdk-for-go/arm/compute"
"github.com/Azure/azure-sdk-for-go/arm/network"
"github.com/Azure/azure-sdk-for-go/arm/resources/resources"
armStorage "github.com/Azure/azure-sdk-for-go/arm/storage"
"github.com/Azure/azure-sdk-for-go/storage"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/version"
)
const (
EnvPackerLogAzureMaxLen = "PACKER_LOG_AZURE_MAXLEN"
)
var (
packerUserAgent = fmt.Sprintf(";packer/%s", version.FormattedVersion())
)
type AzureClient struct {
compute.VirtualMachinesClient
network.PublicIPAddressesClient
resources.GroupsClient
resources.DeploymentsClient
storage.BlobStorageClient
resources.DeploymentsClient
resources.GroupsClient
network.PublicIPAddressesClient
compute.VirtualMachinesClient
common.VaultClient
armStorage.AccountsClient
InspectorMaxLength int
Template *CaptureTemplate
}
func NewAzureClient(subscriptionID string, resourceGroupName string, storageAccountName string, servicePrincipalToken *azure.ServicePrincipalToken) (*AzureClient, error) {
func getCaptureResponse(body string) *CaptureTemplate {
var operation CaptureOperation
err := json.Unmarshal([]byte(body), &operation)
if err != nil {
return nil
}
if operation.Properties != nil && operation.Properties.Output != nil {
return operation.Properties.Output
}
return nil
}
// HACK(chrboum): This method is a hack. It was written to work around this issue
// (https://github.com/Azure/azure-sdk-for-go/issues/307) and to an extent this
// issue (https://github.com/Azure/azure-rest-api-specs/issues/188).
//
// Capturing a VM is a long running operation that requires polling. There are
// couple different forms of polling, and the end result of a poll operation is
// discarded by the SDK. It is expected that any discarded data can be re-fetched,
// so discarding it has minimal impact. Unfortunately, there is no way to re-fetch
// the template returned by a capture call that I am aware of.
//
// If the second issue were fixed the VM ID would be included when GET'ing a VM. The
// VM ID could be used to locate the captured VHD, and captured template.
// Unfortunately, the VM ID is not included so this method cannot be used either.
//
// This code captures the template and saves it to the client (the AzureClient type).
// It expects that the capture API is called only once, or rather you only care that the
// last call's value is important because subsequent requests are not persisted. There
// is no care given to multiple threads writing this value because for our use case
// it does not matter.
func templateCapture(client *AzureClient) autorest.RespondDecorator {
return func(r autorest.Responder) autorest.Responder {
return autorest.ResponderFunc(func(resp *http.Response) error {
body, bodyString := handleBody(resp.Body, math.MaxInt64)
resp.Body = body
captureTemplate := getCaptureResponse(bodyString)
if captureTemplate != nil {
client.Template = captureTemplate
}
return r.Respond(resp)
})
}
}
// WAITING(chrboum): I have logged https://github.com/Azure/azure-sdk-for-go/issues/311 to get this
// method included in the SDK. It has been accepted, and I'll cut over to the official way
// once it ships.
func byConcatDecorators(decorators ...autorest.RespondDecorator) autorest.RespondDecorator {
return func(r autorest.Responder) autorest.Responder {
return autorest.DecorateResponder(r, decorators...)
}
}
func NewAzureClient(subscriptionID, resourceGroupName, storageAccountName string,
servicePrincipalToken, servicePrincipalTokenVault *azure.ServicePrincipalToken) (*AzureClient, error) {
var azureClient = &AzureClient{}
maxlen := getInspectorMaxLength()
azureClient.DeploymentsClient = resources.NewDeploymentsClient(subscriptionID)
azureClient.DeploymentsClient.Authorizer = servicePrincipalToken
azureClient.DeploymentsClient.RequestInspector = withInspection(maxlen)
azureClient.DeploymentsClient.ResponseInspector = byInspecting(maxlen)
azureClient.DeploymentsClient.UserAgent += packerUserAgent
azureClient.GroupsClient = resources.NewGroupsClient(subscriptionID)
azureClient.GroupsClient.Authorizer = servicePrincipalToken
azureClient.GroupsClient.RequestInspector = withInspection(maxlen)
azureClient.GroupsClient.ResponseInspector = byInspecting(maxlen)
azureClient.GroupsClient.UserAgent += packerUserAgent
azureClient.PublicIPAddressesClient = network.NewPublicIPAddressesClient(subscriptionID)
azureClient.PublicIPAddressesClient.Authorizer = servicePrincipalToken
azureClient.PublicIPAddressesClient.RequestInspector = withInspection(maxlen)
azureClient.PublicIPAddressesClient.ResponseInspector = byInspecting(maxlen)
azureClient.PublicIPAddressesClient.UserAgent += packerUserAgent
azureClient.VirtualMachinesClient = compute.NewVirtualMachinesClient(subscriptionID)
azureClient.VirtualMachinesClient.Authorizer = servicePrincipalToken
azureClient.VirtualMachinesClient.RequestInspector = withInspection(maxlen)
azureClient.VirtualMachinesClient.ResponseInspector = byConcatDecorators(byInspecting(maxlen), templateCapture(azureClient))
azureClient.VirtualMachinesClient.UserAgent += packerUserAgent
storageAccountsClient := armStorage.NewAccountsClient(subscriptionID)
storageAccountsClient.Authorizer = servicePrincipalToken
azureClient.AccountsClient = armStorage.NewAccountsClient(subscriptionID)
azureClient.AccountsClient.Authorizer = servicePrincipalToken
azureClient.AccountsClient.RequestInspector = withInspection(maxlen)
azureClient.AccountsClient.ResponseInspector = byInspecting(maxlen)
azureClient.AccountsClient.UserAgent += packerUserAgent
accountKeys, err := storageAccountsClient.ListKeys(resourceGroupName, storageAccountName)
azureClient.VaultClient = common.VaultClient{}
azureClient.VaultClient.Authorizer = servicePrincipalTokenVault
azureClient.VaultClient.RequestInspector = withInspection(maxlen)
azureClient.VaultClient.ResponseInspector = byInspecting(maxlen)
azureClient.VaultClient.UserAgent += packerUserAgent
accountKeys, err := azureClient.AccountsClient.ListKeys(resourceGroupName, storageAccountName)
if err != nil {
return nil, err
}
@ -52,3 +157,21 @@ func NewAzureClient(subscriptionID string, resourceGroupName string, storageAcco
azureClient.BlobStorageClient = storageClient.GetBlobService()
return azureClient, nil
}
func getInspectorMaxLength() int64 {
value, ok := os.LookupEnv(EnvPackerLogAzureMaxLen)
if !ok {
return math.MaxInt64
}
i, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return 0
}
if i < 0 {
return 0
}
return i
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,12 +7,15 @@ import (
"errors"
"fmt"
"log"
"time"
packerAzureCommon "github.com/mitchellh/packer/builder/azure/common"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/packer/builder/azure/common/lin"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/common"
"github.com/mitchellh/packer/helper/communicator"
@ -27,6 +30,9 @@ type Builder struct {
const (
DefaultPublicIPAddressName = "packerPublicIP"
DefaultSasBlobContainer = "system/Microsoft.Compute"
DefaultSasBlobPermission = "r"
DefaultSecretName = "packerKeyVaultSecret"
)
func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
@ -38,10 +44,8 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
b.config = c
b.stateBag = new(multistep.BasicStateBag)
err := b.configureStateBag(b.stateBag)
if err != nil {
return nil, err
}
b.configureStateBag(b.stateBag)
b.setTemplateParameters(b.stateBag)
return warnings, errs
}
@ -52,33 +56,81 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe
b.stateBag.Put("hook", hook)
b.stateBag.Put(constants.Ui, ui)
servicePrincipalToken, err := b.createServicePrincipalToken()
spnCloud, spnKeyVault, err := b.getServicePrincipalTokens(ui.Say)
if err != nil {
return nil, err
}
ui.Message("Creating Azure Resource Manager (ARM) client ...")
azureClient, err := NewAzureClient(b.config.SubscriptionID, b.config.ResourceGroupName, b.config.StorageAccount, servicePrincipalToken)
azureClient, err := NewAzureClient(
b.config.SubscriptionID,
b.config.ResourceGroupName,
b.config.StorageAccount,
spnCloud,
spnKeyVault)
if err != nil {
return nil, err
}
steps := []multistep.Step{
NewStepCreateResourceGroup(azureClient, ui),
NewStepValidateTemplate(azureClient, ui),
NewStepDeployTemplate(azureClient, ui),
NewStepGetIPAddress(azureClient, ui),
&communicator.StepConnectSSH{
Config: &b.config.Comm,
Host: lin.SSHHost,
SSHConfig: lin.SSHConfig(b.config.UserName),
},
&common.StepProvision{},
NewStepGetOSDisk(azureClient, ui),
NewStepPowerOffCompute(azureClient, ui),
NewStepCaptureImage(azureClient, ui),
NewStepDeleteResourceGroup(azureClient, ui),
NewStepDeleteOSDisk(azureClient, ui),
b.config.storageAccountBlobEndpoint, err = b.getBlobEndpoint(azureClient, b.config.ResourceGroupName, b.config.StorageAccount)
if err != nil {
return nil, err
}
b.setTemplateParameters(b.stateBag)
var steps []multistep.Step
if b.config.OSType == constants.Target_Linux {
steps = []multistep.Step{
NewStepCreateResourceGroup(azureClient, ui),
NewStepValidateTemplate(azureClient, ui, Linux),
NewStepDeployTemplate(azureClient, ui, Linux),
NewStepGetIPAddress(azureClient, ui),
&communicator.StepConnectSSH{
Config: &b.config.Comm,
Host: lin.SSHHost,
SSHConfig: lin.SSHConfig(b.config.UserName),
},
&common.StepProvision{},
NewStepGetOSDisk(azureClient, ui),
NewStepPowerOffCompute(azureClient, ui),
NewStepCaptureImage(azureClient, ui),
NewStepDeleteResourceGroup(azureClient, ui),
NewStepDeleteOSDisk(azureClient, ui),
}
} else if b.config.OSType == constants.Target_Windows {
steps = []multistep.Step{
NewStepCreateResourceGroup(azureClient, ui),
NewStepValidateTemplate(azureClient, ui, KeyVault),
NewStepDeployTemplate(azureClient, ui, KeyVault),
NewStepGetCertificate(azureClient, ui),
NewStepSetCertificate(b.config, ui),
NewStepValidateTemplate(azureClient, ui, Windows),
NewStepDeployTemplate(azureClient, ui, Windows),
NewStepGetIPAddress(azureClient, ui),
&communicator.StepConnectWinRM{
Config: &b.config.Comm,
Host: func(stateBag multistep.StateBag) (string, error) {
return stateBag.Get(constants.SSHHost).(string), nil
},
WinRMConfig: func(multistep.StateBag) (*communicator.WinRMConfig, error) {
return &communicator.WinRMConfig{
Username: b.config.UserName,
Password: b.config.tmpAdminPassword,
}, nil
},
},
&common.StepProvision{},
NewStepGetOSDisk(azureClient, ui),
NewStepPowerOffCompute(azureClient, ui),
NewStepCaptureImage(azureClient, ui),
NewStepDeleteResourceGroup(azureClient, ui),
NewStepDeleteOSDisk(azureClient, ui),
}
} else {
return nil, fmt.Errorf("Builder does not support the os_type '%s'", b.config.OSType)
}
if b.config.PackerDebug {
@ -103,7 +155,17 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe
return nil, errors.New("Build was halted.")
}
return &artifact{}, nil
if template, ok := b.stateBag.GetOk(constants.ArmCaptureTemplate); ok {
return NewArtifact(
template.(*CaptureTemplate),
func(name string) string {
month := time.Now().AddDate(0, 1, 0).UTC()
sasUrl, _ := azureClient.BlobStorageClient.GetBlobSASURI(DefaultSasBlobContainer, name, month, DefaultSasBlobPermission)
return sasUrl
})
}
return &Artifact{}, nil
}
func (b *Builder) Cancel() {
@ -126,33 +188,57 @@ func (b *Builder) createRunner(steps *[]multistep.Step, ui packer.Ui) multistep.
}
}
func (b *Builder) configureStateBag(stateBag multistep.StateBag) error {
func (b *Builder) getBlobEndpoint(client *AzureClient, resourceGroupName string, storageAccountName string) (string, error) {
account, err := client.AccountsClient.GetProperties(resourceGroupName, storageAccountName)
if err != nil {
return "", err
}
return *account.Properties.PrimaryEndpoints.Blob, nil
}
func (b *Builder) configureStateBag(stateBag multistep.StateBag) {
stateBag.Put(constants.AuthorizedKey, b.config.sshAuthorizedKey)
stateBag.Put(constants.PrivateKey, b.config.sshPrivateKey)
stateBag.Put(constants.ArmComputeName, b.config.tmpComputeName)
stateBag.Put(constants.ArmDeploymentName, b.config.tmpDeploymentName)
stateBag.Put(constants.ArmKeyVaultName, b.config.tmpKeyVaultName)
stateBag.Put(constants.ArmLocation, b.config.Location)
stateBag.Put(constants.ArmPublicIPAddressName, DefaultPublicIPAddressName)
stateBag.Put(constants.ArmResourceGroupName, b.config.tmpResourceGroupName)
stateBag.Put(constants.ArmStorageAccountName, b.config.StorageAccount)
}
func (b *Builder) setTemplateParameters(stateBag multistep.StateBag) {
stateBag.Put(constants.ArmTemplateParameters, b.config.toTemplateParameters())
stateBag.Put(constants.ArmVirtualMachineCaptureParameters, b.config.toVirtualMachineCaptureParameters())
stateBag.Put(constants.ArmPublicIPAddressName, DefaultPublicIPAddressName)
return nil
}
func (b *Builder) createServicePrincipalToken() (*azure.ServicePrincipalToken, error) {
oauthConfig, err := azure.PublicCloud.OAuthConfigForTenant(b.config.TenantID)
if err != nil {
return nil, err
func (b *Builder) getServicePrincipalTokens(say func(string)) (*azure.ServicePrincipalToken, *azure.ServicePrincipalToken, error) {
var servicePrincipalToken *azure.ServicePrincipalToken
var servicePrincipalTokenVault *azure.ServicePrincipalToken
var err error
if b.config.useDeviceLogin {
servicePrincipalToken, err = packerAzureCommon.Authenticate(*b.config.cloudEnvironment, b.config.SubscriptionID, say)
if err != nil {
return nil, nil, err
}
} else {
auth := NewAuthenticate(*b.config.cloudEnvironment, b.config.ClientID, b.config.ClientSecret, b.config.TenantID)
servicePrincipalToken, err = auth.getServicePrincipalToken()
if err != nil {
return nil, nil, err
}
servicePrincipalTokenVault, err = auth.getServicePrincipalTokenWithResource(packerAzureCommon.AzureVaultScope)
if err != nil {
return nil, nil, err
}
}
spt, err := azure.NewServicePrincipalToken(
*oauthConfig,
b.config.ClientID,
b.config.ClientSecret,
azure.PublicCloud.ResourceManagerEndpoint)
return spt, err
return servicePrincipalToken, servicePrincipalTokenVault, nil
}

View File

@ -1,17 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"testing"
"github.com/mitchellh/packer/builder/azure/common/constants"
"testing"
)
func TestStateBagShouldBePopulatedExpectedValues(t *testing.T) {
var testSubject = &Builder{}
testSubject.Prepare(getArmBuilderConfiguration(), getPackerConfiguration())
_, err := testSubject.Prepare(getArmBuilderConfiguration(), getPackerConfiguration())
if err != nil {
t.Fatalf("failed to prepare: %s", err)
}
var expectedStateBagKeys = []string{
constants.AuthorizedKey,
@ -21,6 +23,7 @@ func TestStateBagShouldBePopulatedExpectedValues(t *testing.T) {
constants.ArmDeploymentName,
constants.ArmLocation,
constants.ArmResourceGroupName,
constants.ArmStorageAccountName,
constants.ArmTemplateParameters,
constants.ArmVirtualMachineCaptureParameters,
constants.ArmPublicIPAddressName,

View File

@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
type CaptureTemplateParameter struct {
Type string `json:"type"`
DefaultValue string `json:"defaultValue,omitempty"`
}
type CaptureHardwareProfile struct {
VMSize string `json:"vmSize"`
}
type CaptureUri struct {
Uri string `json:"uri"`
}
type CaptureDisk struct {
OSType string `json:"osType"`
Name string `json:"name"`
Image CaptureUri `json:"image"`
Vhd CaptureUri `json:"vhd"`
CreateOption string `json:"createOption"`
Caching string `json:"caching"`
}
type CaptureStorageProfile struct {
OSDisk CaptureDisk `json:"osDisk"`
}
type CaptureOSProfile struct {
ComputerName string `json:"computerName"`
AdminUsername string `json:"adminUsername"`
AdminPassword string `json:"adminPassword"`
}
type CaptureNetworkInterface struct {
Id string `json:"id"`
}
type CaptureNetworkProfile struct {
NetworkInterfaces []CaptureNetworkInterface `json:"networkInterfaces"`
}
type CaptureBootDiagnostics struct {
Enabled bool `json:"enabled"`
}
type CaptureDiagnosticProfile struct {
BootDiagnostics CaptureBootDiagnostics `json:"bootDiagnostics"`
}
type CaptureProperties struct {
HardwareProfile CaptureHardwareProfile `json:"hardwareProfile"`
StorageProfile CaptureStorageProfile `json:"storageProfile"`
OSProfile CaptureOSProfile `json:"osProfile"`
NetworkProfile CaptureNetworkProfile `json:"networkProfile"`
DiagnosticsProfile CaptureDiagnosticProfile `json:"diagnosticsProfile"`
ProvisioningState int `json:"provisioningState"`
}
type CaptureResources struct {
ApiVersion string `json:"apiVersion"`
Name string `json:"name"`
Type string `json:"type"`
Location string `json:"location"`
Properties CaptureProperties `json:"properties"`
}
type CaptureTemplate struct {
Schema string `json:"$schema"`
ContentVersion string `json:"contentVersion"`
Parameters map[string]CaptureTemplateParameter `json:"parameters"`
Resources []CaptureResources `json:"resources"`
}
type CaptureOperationProperties struct {
Output *CaptureTemplate `json:"output"`
}
type CaptureOperation struct {
OperationId string `json:"operationId"`
Status string `json:"status"`
Properties *CaptureOperationProperties `json:"properties"`
}

View File

@ -0,0 +1,213 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"encoding/json"
"testing"
)
var captureTemplate01 = `{
"operationId": "ac1c7c38-a591-41b3-89bd-ea39fceace1b",
"status": "Succeeded",
"startTime": "2016-04-04T21:07:25.2900874+00:00",
"endTime": "2016-04-04T21:07:26.4776321+00:00",
"properties": {
"output": {
"$schema": "http://schema.management.azure.com/schemas/2014-04-01-preview/VM_IP.json",
"contentVersion": "1.0.0.0",
"parameters": {
"vmName": {
"type": "string"
},
"vmSize": {
"type": "string",
"defaultValue": "Standard_A2"
},
"adminUserName": {
"type": "string"
},
"adminPassword": {
"type": "securestring"
},
"networkInterfaceId": {
"type": "string"
}
},
"resources": [
{
"apiVersion": "2015-06-15",
"properties": {
"hardwareProfile": {
"vmSize": "[parameters('vmSize')]"
},
"storageProfile": {
"osDisk": {
"osType": "Linux",
"name": "packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd",
"createOption": "FromImage",
"image": {
"uri": "http://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd"
},
"vhd": {
"uri": "http://storage.blob.core.windows.net/vmcontainerce1a1b75-f480-47cb-8e6e-55142e4a5f68/osDisk.ce1a1b75-f480-47cb-8e6e-55142e4a5f68.vhd"
},
"caching": "ReadWrite"
}
},
"osProfile": {
"computerName": "[parameters('vmName')]",
"adminUsername": "[parameters('adminUsername')]",
"adminPassword": "[parameters('adminPassword')]"
},
"networkProfile": {
"networkInterfaces": [
{
"id": "[parameters('networkInterfaceId')]"
}
]
},
"diagnosticsProfile": {
"bootDiagnostics": {
"enabled": false
}
},
"provisioningState": 0
},
"name": "[parameters('vmName')]",
"type": "Microsoft.Compute/virtualMachines",
"location": "southcentralus"
}
]
}
}
}`
var captureTemplate02 = `{
"operationId": "ac1c7c38-a591-41b3-89bd-ea39fceace1b",
"status": "Succeeded",
"startTime": "2016-04-04T21:07:25.2900874+00:00",
"endTime": "2016-04-04T21:07:26.4776321+00:00"
}`
func TestCaptureParseJson(t *testing.T) {
var operation CaptureOperation
err := json.Unmarshal([]byte(captureTemplate01), &operation)
if err != nil {
t.Fatalf("failed to the sample capture operation: %s", err)
}
testSubject := operation.Properties.Output
if testSubject.Schema != "http://schema.management.azure.com/schemas/2014-04-01-preview/VM_IP.json" {
t.Errorf("Schema's value was unexpected: %s", testSubject.Schema)
}
if testSubject.ContentVersion != "1.0.0.0" {
t.Errorf("ContentVersion's value was unexpected: %s", testSubject.ContentVersion)
}
// == Parameters ====================================
if len(testSubject.Parameters) != 5 {
t.Fatalf("expected parameters to have 5 keys, but got %d", len(testSubject.Parameters))
}
if _, ok := testSubject.Parameters["vmName"]; !ok {
t.Errorf("Parameters['vmName'] was an expected parameters, but it did not exist")
}
if testSubject.Parameters["vmName"].Type != "string" {
t.Errorf("Parameters['vmName'].Type == 'string', but got '%s'", testSubject.Parameters["vmName"].Type)
}
if _, ok := testSubject.Parameters["vmSize"]; !ok {
t.Errorf("Parameters['vmSize'] was an expected parameters, but it did not exist")
}
if testSubject.Parameters["vmSize"].Type != "string" {
t.Errorf("Parameters['vmSize'].Type == 'string', but got '%s'", testSubject.Parameters["vmSize"])
}
if testSubject.Parameters["vmSize"].DefaultValue != "Standard_A2" {
t.Errorf("Parameters['vmSize'].DefaultValue == 'string', but got '%s'", testSubject.Parameters["vmSize"].DefaultValue)
}
// == Resources =====================================
if len(testSubject.Resources) != 1 {
t.Fatalf("expected resources to have length 1, but got %d", len(testSubject.Resources))
}
if testSubject.Resources[0].Name != "[parameters('vmName')]" {
t.Errorf("Resources[0].Name's value was unexpected: %s", testSubject.Resources[0].Name)
}
if testSubject.Resources[0].Type != "Microsoft.Compute/virtualMachines" {
t.Errorf("Resources[0].Type's value was unexpected: %s", testSubject.Resources[0].Type)
}
if testSubject.Resources[0].Location != "southcentralus" {
t.Errorf("Resources[0].Location's value was unexpected: %s", testSubject.Resources[0].Location)
}
// == Resources/Properties =====================================
if testSubject.Resources[0].Properties.ProvisioningState != 0 {
t.Errorf("Resources[0].Properties.ProvisioningState's value was unexpected: %d", testSubject.Resources[0].Properties.ProvisioningState)
}
// == Resources/Properties/HardwareProfile ======================
hardwareProfile := testSubject.Resources[0].Properties.HardwareProfile
if hardwareProfile.VMSize != "[parameters('vmSize')]" {
t.Errorf("Resources[0].Properties.HardwareProfile.VMSize's value was unexpected: %s", hardwareProfile.VMSize)
}
// == Resources/Properties/StorageProfile/OSDisk ================
osDisk := testSubject.Resources[0].Properties.StorageProfile.OSDisk
if osDisk.OSType != "Linux" {
t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.OSDisk's value was unexpected: %s", osDisk.OSType)
}
if osDisk.Name != "packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd" {
t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Name's value was unexpected: %s", osDisk.Name)
}
if osDisk.CreateOption != "FromImage" {
t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.CreateOption's value was unexpected: %s", osDisk.CreateOption)
}
if osDisk.Image.Uri != "http://storage.blob.core.windows.net/system/Microsoft.Compute/Images/images/packer-osDisk.32118633-6dc9-449f-83b6-a7d2983bec14.vhd" {
t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Image.Uri's value was unexpected: %s", osDisk.Image.Uri)
}
if osDisk.Vhd.Uri != "http://storage.blob.core.windows.net/vmcontainerce1a1b75-f480-47cb-8e6e-55142e4a5f68/osDisk.ce1a1b75-f480-47cb-8e6e-55142e4a5f68.vhd" {
t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Vhd.Uri's value was unexpected: %s", osDisk.Vhd.Uri)
}
if osDisk.Caching != "ReadWrite" {
t.Errorf("Resources[0].Properties.StorageProfile.OSDisk.Caching's value was unexpected: %s", osDisk.Caching)
}
// == Resources/Properties/OSProfile ============================
osProfile := testSubject.Resources[0].Properties.OSProfile
if osProfile.AdminPassword != "[parameters('adminPassword')]" {
t.Errorf("Resources[0].Properties.OSProfile.AdminPassword's value was unexpected: %s", osProfile.AdminPassword)
}
if osProfile.AdminUsername != "[parameters('adminUsername')]" {
t.Errorf("Resources[0].Properties.OSProfile.AdminUsername's value was unexpected: %s", osProfile.AdminUsername)
}
if osProfile.ComputerName != "[parameters('vmName')]" {
t.Errorf("Resources[0].Properties.OSProfile.ComputerName's value was unexpected: %s", osProfile.ComputerName)
}
// == Resources/Properties/NetworkProfile =======================
networkProfile := testSubject.Resources[0].Properties.NetworkProfile
if len(networkProfile.NetworkInterfaces) != 1 {
t.Errorf("Count of Resources[0].Properties.NetworkProfile.NetworkInterfaces was expected to be 1, but go %d", len(networkProfile.NetworkInterfaces))
}
if networkProfile.NetworkInterfaces[0].Id != "[parameters('networkInterfaceId')]" {
t.Errorf("Resources[0].Properties.NetworkProfile.NetworkInterfaces[0].Id's value was unexpected: %s", networkProfile.NetworkInterfaces[0].Id)
}
// == Resources/Properties/DiagnosticsProfile ===================
diagnosticsProfile := testSubject.Resources[0].Properties.DiagnosticsProfile
if diagnosticsProfile.BootDiagnostics.Enabled != false {
t.Errorf("Resources[0].Properties.DiagnosticsProfile.BootDiagnostics.Enabled's value was unexpected: %t", diagnosticsProfile.BootDiagnostics.Enabled)
}
}
func TestCaptureEmptyOperationJson(t *testing.T) {
var operation CaptureOperation
err := json.Unmarshal([]byte(captureTemplate02), &operation)
if err != nil {
t.Fatalf("failed to the sample capture operation: %s", err)
}
if operation.Properties != nil {
t.Errorf("JSON contained no properties, but value was not nil: %+v", operation.Properties)
}
}

View File

@ -1,29 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"time"
"golang.org/x/crypto/ssh"
"github.com/Azure/azure-sdk-for-go/arm/compute"
"github.com/Azure/go-autorest/autorest/to"
"github.com/Azure/go-ntlmssp"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/packer/builder/azure/pkcs12"
"github.com/mitchellh/packer/common"
"github.com/mitchellh/packer/helper/communicator"
"github.com/mitchellh/packer/helper/config"
"github.com/mitchellh/packer/packer"
"github.com/mitchellh/packer/template/interpolate"
"github.com/Azure/go-autorest/autorest/azure"
"golang.org/x/crypto/ssh"
"strings"
)
const (
DefaultUserName = "packer"
DefaultVMSize = "Standard_A1"
DefaultCloudEnvironmentName = "Public"
DefaultImageVersion = "latest"
DefaultUserName = "packer"
DefaultVMSize = "Standard_A1"
)
type Config struct {
@ -32,6 +46,7 @@ type Config struct {
// Authentication via OAUTH
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
ObjectID string `mapstructure:"object_id"`
TenantID string `mapstructure:"tenant_id"`
SubscriptionID string `mapstructure:"subscription_id"`
@ -43,46 +58,81 @@ type Config struct {
ImagePublisher string `mapstructure:"image_publisher"`
ImageOffer string `mapstructure:"image_offer"`
ImageSku string `mapstructure:"image_sku"`
ImageVersion string `mapstructure:"image_version"`
Location string `mapstructure:"location"`
VMSize string `mapstructure:"vm_size"`
// Deployment
ResourceGroupName string `mapstructure:"resource_group_name"`
StorageAccount string `mapstructure:"storage_account"`
ResourceGroupName string `mapstructure:"resource_group_name"`
StorageAccount string `mapstructure:"storage_account"`
storageAccountBlobEndpoint string
CloudEnvironmentName string `mapstructure:"cloud_environment_name"`
cloudEnvironment *azure.Environment
// OS
OSType string `mapstructure:"os_type"`
// Runtime Values
UserName string
Password string
tmpAdminPassword string
tmpResourceGroupName string
tmpComputeName string
tmpDeploymentName string
tmpOSDiskName string
UserName string
Password string
tmpAdminPassword string
tmpCertificatePassword string
tmpResourceGroupName string
tmpComputeName string
tmpDeploymentName string
tmpKeyVaultName string
tmpOSDiskName string
tmpWinRMCertificateUrl string
useDeviceLogin bool
// Authentication with the VM via SSH
sshAuthorizedKey string
sshPrivateKey string
// Authentication with the VM via WinRM
winrmCertificate string
Comm communicator.Config `mapstructure:",squash"`
ctx *interpolate.Context
}
type keyVaultCertificate struct {
Data string `json:"data"`
DataType string `json:"dataType"`
Password string `json:"password,omitempty"`
}
// If we ever feel the need to support more templates consider moving this
// method to its own factory class.
func (c *Config) toTemplateParameters() *TemplateParameters {
return &TemplateParameters{
AdminUsername: &TemplateParameter{c.UserName},
AdminPassword: &TemplateParameter{c.Password},
DnsNameForPublicIP: &TemplateParameter{c.tmpComputeName},
ImageOffer: &TemplateParameter{c.ImageOffer},
ImagePublisher: &TemplateParameter{c.ImagePublisher},
ImageSku: &TemplateParameter{c.ImageSku},
OSDiskName: &TemplateParameter{c.tmpOSDiskName},
SshAuthorizedKey: &TemplateParameter{c.sshAuthorizedKey},
StorageAccountName: &TemplateParameter{c.StorageAccount},
VMSize: &TemplateParameter{c.VMSize},
VMName: &TemplateParameter{c.tmpComputeName},
templateParameters := &TemplateParameters{
AdminUsername: &TemplateParameter{c.UserName},
AdminPassword: &TemplateParameter{c.Password},
DnsNameForPublicIP: &TemplateParameter{c.tmpComputeName},
ImageOffer: &TemplateParameter{c.ImageOffer},
ImagePublisher: &TemplateParameter{c.ImagePublisher},
ImageSku: &TemplateParameter{c.ImageSku},
ImageVersion: &TemplateParameter{c.ImageVersion},
OSDiskName: &TemplateParameter{c.tmpOSDiskName},
StorageAccountBlobEndpoint: &TemplateParameter{c.storageAccountBlobEndpoint},
VMSize: &TemplateParameter{c.VMSize},
VMName: &TemplateParameter{c.tmpComputeName},
}
switch c.OSType {
case constants.Target_Linux:
templateParameters.SshAuthorizedKey = &TemplateParameter{c.sshAuthorizedKey}
case constants.Target_Windows:
templateParameters.TenantId = &TemplateParameter{c.TenantID}
templateParameters.ObjectId = &TemplateParameter{c.ObjectID}
templateParameters.KeyVaultName = &TemplateParameter{c.tmpKeyVaultName}
templateParameters.KeyVaultSecretValue = &TemplateParameter{c.winrmCertificate}
templateParameters.WinRMCertificateUrl = &TemplateParameter{c.tmpWinRMCertificateUrl}
}
return templateParameters
}
func (c *Config) toVirtualMachineCaptureParameters() *compute.VirtualMachineCaptureParameters {
@ -93,6 +143,66 @@ func (c *Config) toVirtualMachineCaptureParameters() *compute.VirtualMachineCapt
}
}
func (c *Config) createCertificate() (string, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
err := fmt.Errorf("Failed to Generate Private Key: %s", err)
return "", err
}
host := fmt.Sprintf("%s.cloudapp.net", c.tmpComputeName)
notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
err := fmt.Errorf("Failed to Generate Serial Number: %v", err)
return "", err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Issuer: pkix.Name{
CommonName: host,
},
Subject: pkix.Name{
CommonName: host,
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
err = fmt.Errorf("Failed to Create Certificate: %s", err)
return "", err
}
pfxBytes, err := pkcs12.Encode(derBytes, privateKey, c.tmpCertificatePassword)
if err != nil {
err = fmt.Errorf("Failed to encode certificate as PFX: %s", err)
return "", err
}
keyVaultDescription := keyVaultCertificate{
Data: base64.StdEncoding.EncodeToString(pfxBytes),
DataType: "pfx",
Password: c.tmpCertificatePassword,
}
bytes, err := json.Marshal(keyVaultDescription)
if err != nil {
err = fmt.Errorf("Failed to marshal key vault description: %s", err)
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func newConfig(raws ...interface{}) (*Config, []string, error) {
var c Config
@ -108,12 +218,21 @@ func newConfig(raws ...interface{}) (*Config, []string, error) {
provideDefaultValues(&c)
setRuntimeValues(&c)
setUserNamePassword(&c)
err = setCloudEnvironment(&c)
if err != nil {
return nil, nil, err
}
err = setSshValues(&c)
if err != nil {
return nil, nil, err
}
err = setWinRMCertificate(&c)
if err != nil {
return nil, nil, err
}
var errs *packer.MultiError
errs = packer.MultiErrorAppend(errs, c.Comm.Prepare(c.ctx)...)
@ -157,18 +276,30 @@ func setSshValues(c *Config) error {
c.sshPrivateKey = sshKeyPair.PrivateKey()
}
c.Comm.WinRMTransportDecorator = func(t *http.Transport) http.RoundTripper {
return &ntlmssp.Negotiator{RoundTripper: t}
}
return nil
}
func setWinRMCertificate(c *Config) error {
cert, err := c.createCertificate()
c.winrmCertificate = cert
return err
}
func setRuntimeValues(c *Config) {
var tempName = NewTempName()
c.tmpAdminPassword = tempName.AdminPassword
c.tmpCertificatePassword = tempName.CertificatePassword
c.tmpComputeName = tempName.ComputeName
c.tmpDeploymentName = tempName.DeploymentName
// c.tmpResourceGroupName = c.ResourceGroupName
c.tmpResourceGroupName = tempName.ResourceGroupName
c.tmpOSDiskName = tempName.OSDiskName
c.tmpKeyVaultName = tempName.KeyVaultName
}
func setUserNamePassword(c *Config) {
@ -185,30 +316,77 @@ func setUserNamePassword(c *Config) {
}
}
func setCloudEnvironment(c *Config) error {
name := strings.ToUpper(c.CloudEnvironmentName)
switch name {
case "CHINA", "CHINACLOUD", "AZURECHINACLOUD":
c.cloudEnvironment = &azure.ChinaCloud
case "PUBLIC", "PUBLICCLOUD", "AZUREPUBLICCLOUD":
c.cloudEnvironment = &azure.PublicCloud
case "USGOVERNMENT", "USGOVERNMENTCLOUD", "AZUREUSGOVERNMENTCLOUD":
c.cloudEnvironment = &azure.USGovernmentCloud
default:
return fmt.Errorf("There is no cloud envionment matching the name '%s'!", c.CloudEnvironmentName)
}
return nil
}
func provideDefaultValues(c *Config) {
if c.VMSize == "" {
c.VMSize = DefaultVMSize
}
if c.ImageVersion == "" {
c.ImageVersion = DefaultImageVersion
}
if c.CloudEnvironmentName == "" {
c.CloudEnvironmentName = DefaultCloudEnvironmentName
}
}
func assertRequiredParametersSet(c *Config, errs *packer.MultiError) {
/////////////////////////////////////////////
// Authentication via OAUTH
if c.ClientID == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_id must be specified"))
// Check if device login is being asked for, and is allowed.
//
// Device login is enabled if the user only defines SubscriptionID and not
// ClientID, ClientSecret, and TenantID.
//
// Device login is not enabled for Windows because the WinRM certificate is
// readable by the ObjectID of the App. There may be another way to handle
// this case, but I am not currently aware of it - send feedback.
isUseDeviceLogin := func(c *Config) bool {
if c.OSType == constants.Target_Windows {
return false
}
return c.SubscriptionID != "" &&
c.ClientID == "" &&
c.ClientSecret == "" &&
c.TenantID == ""
}
if c.ClientSecret == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_secret must be specified"))
}
if isUseDeviceLogin(c) {
c.useDeviceLogin = true
} else {
if c.ClientID == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_id must be specified"))
}
if c.TenantID == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A tenant_id must be specified"))
}
if c.ClientSecret == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A client_secret must be specified"))
}
if c.SubscriptionID == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A subscription_id must be specified"))
if c.TenantID == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A tenant_id must be specified"))
}
if c.SubscriptionID == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A subscription_id must be specified"))
}
}
/////////////////////////////////////////////
@ -223,7 +401,6 @@ func assertRequiredParametersSet(c *Config, errs *packer.MultiError) {
/////////////////////////////////////////////
// Compute
if c.ImagePublisher == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A image_publisher must be specified"))
}
@ -242,8 +419,13 @@ func assertRequiredParametersSet(c *Config, errs *packer.MultiError) {
/////////////////////////////////////////////
// Deployment
if c.StorageAccount == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("A storage_account must be specified"))
}
/////////////////////////////////////////////
// OS
if c.OSType != constants.Target_Linux && c.OSType != constants.Target_Windows {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("An os_type must be specified"))
}
}

View File

@ -1,14 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/packer/packer"
)
// List of configuration parameters that are required by the ARM builder.
@ -21,6 +24,7 @@ var requiredConfigValues = []string{
"image_publisher",
"image_sku",
"location",
"os_type",
"storage_account",
"subscription_id",
"tenant_id",
@ -41,21 +45,23 @@ func TestConfigShouldProvideReasonableDefaultValues(t *testing.T) {
if c.VMSize == "" {
t.Errorf("Expected 'VMSize' to be populated, but it was empty!")
}
if c.ObjectID != "" {
t.Errorf("Expected 'ObjectID' to be nil, but it was '%s'!", c.ObjectID)
}
}
func TestConfigShouldBeAbleToOverrideDefaultedValues(t *testing.T) {
builderValues := make(map[string]string)
// Populate the dictionary with all of the required values.
for _, v := range requiredConfigValues {
builderValues[v] = "--some-value--"
}
builderValues := getArmBuilderConfiguration()
builderValues["ssh_password"] = "override_password"
builderValues["ssh_username"] = "override_username"
builderValues["vm_size"] = "override_vm_size"
c, _, _ := newConfig(getArmBuilderConfigurationFromMap(builderValues), getPackerConfiguration())
c, _, err := newConfig(builderValues, getPackerConfiguration())
if err != nil {
t.Fatalf("newConfig failed: %s", err)
}
if c.Password != "override_password" {
t.Errorf("Expected 'Password' to be set to 'override_password', but found '%s'!", c.Password)
@ -74,7 +80,7 @@ func TestConfigShouldBeAbleToOverrideDefaultedValues(t *testing.T) {
}
if c.VMSize != "override_vm_size" {
t.Errorf("Expected 'vm_size' to be set to 'override_username', but found '%s'!", c.VMSize)
t.Errorf("Expected 'vm_size' to be set to 'override_vm_size', but found '%s'!", c.VMSize)
}
}
@ -86,16 +92,77 @@ func TestConfigShouldDefaultVMSizeToStandardA1(t *testing.T) {
}
}
func TestUserShouldProvideRequiredValues(t *testing.T) {
builderValues := make(map[string]string)
func TestConfigShouldDefaultImageVersionToLatest(t *testing.T) {
c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration())
// Populate the dictionary with all of the required values.
for _, v := range requiredConfigValues {
builderValues[v] = "--some-value--"
if c.ImageVersion != "latest" {
t.Errorf("Expected 'ImageVersion' to default to 'latest', but got '%s'.", c.ImageVersion)
}
}
func TestConfigShouldDefaultToPublicCloud(t *testing.T) {
c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration())
if c.CloudEnvironmentName != "Public" {
t.Errorf("Expected 'CloudEnvironmentName' to default to 'Public', but got '%s'.", c.CloudEnvironmentName)
}
if c.cloudEnvironment == nil || c.cloudEnvironment.Name != "AzurePublicCloud" {
t.Errorf("Expected 'cloudEnvironment' to be set to 'AzurePublicCloud', but got '%s'.", c.cloudEnvironment)
}
}
func TestConfigInstantiatesCorrectAzureEnvironment(t *testing.T) {
config := map[string]string{
"capture_name_prefix": "ignore",
"capture_container_name": "ignore",
"image_offer": "ignore",
"image_publisher": "ignore",
"image_sku": "ignore",
"location": "ignore",
"storage_account": "ignore",
"subscription_id": "ignore",
"os_type": constants.Target_Linux,
}
// user input is fun :)
var table = []struct {
name string
environmentName string
}{
{"China", "AzureChinaCloud"},
{"ChinaCloud", "AzureChinaCloud"},
{"AzureChinaCloud", "AzureChinaCloud"},
{"aZuReChInAcLoUd", "AzureChinaCloud"},
{"USGovernment", "AzureUSGovernmentCloud"},
{"USGovernmentCloud", "AzureUSGovernmentCloud"},
{"AzureUSGovernmentCloud", "AzureUSGovernmentCloud"},
{"aZuReUsGoVeRnMeNtClOuD", "AzureUSGovernmentCloud"},
{"Public", "AzurePublicCloud"},
{"PublicCloud", "AzurePublicCloud"},
{"AzurePublicCloud", "AzurePublicCloud"},
{"aZuRePuBlIcClOuD", "AzurePublicCloud"},
}
packerConfiguration := getPackerConfiguration()
for _, x := range table {
config["cloud_environment_name"] = x.name
c, _, _ := newConfig(config, packerConfiguration)
if c.cloudEnvironment == nil || c.cloudEnvironment.Name != x.environmentName {
t.Errorf("Expected 'cloudEnvironment' to be set to '%s', but got '%s'.", x.environmentName, c.cloudEnvironment)
}
}
}
func TestUserShouldProvideRequiredValues(t *testing.T) {
builderValues := getArmBuilderConfiguration()
// Ensure we can successfully create a config.
_, _, err := newConfig(getArmBuilderConfigurationFromMap(builderValues), getPackerConfiguration())
_, _, err := newConfig(builderValues, getPackerConfiguration())
if err != nil {
t.Errorf("Expected configuration creation to succeed, but it failed!\n")
t.Fatalf(" -> %+v\n", builderValues)
@ -103,15 +170,16 @@ func TestUserShouldProvideRequiredValues(t *testing.T) {
// Take away a required element, and ensure construction fails.
for _, v := range requiredConfigValues {
originalValue := builderValues[v]
delete(builderValues, v)
_, _, err := newConfig(getArmBuilderConfigurationFromMap(builderValues), getPackerConfiguration())
_, _, err := newConfig(builderValues, getPackerConfiguration())
if err == nil {
t.Errorf("Expected configuration creation to fail, but it succeeded!\n")
t.Fatalf(" -> %+v\n", builderValues)
}
builderValues[v] = "--some-value--"
builderValues[v] = originalValue
}
}
@ -167,8 +235,8 @@ func TestConfigShouldTransformToTemplateParameters(t *testing.T) {
t.Errorf("Expected OSDiskName to be equal to config's OSDiskName, but they were '%s' and '%s' respectively.", templateParameters.OSDiskName.Value, c.tmpOSDiskName)
}
if templateParameters.StorageAccountName.Value != c.StorageAccount {
t.Errorf("Expected StorageAccountName to be equal to config's StorageAccountName, but they were '%s' and '%s' respectively.", templateParameters.StorageAccountName.Value, c.StorageAccount)
if templateParameters.StorageAccountBlobEndpoint.Value != c.storageAccountBlobEndpoint {
t.Errorf("Expected StorageAccountBlobEndpoint to be equal to config's storageAccountBlobEndpoint, but they were '%s' and '%s' respectively.", templateParameters.StorageAccountBlobEndpoint.Value, c.storageAccountBlobEndpoint)
}
if templateParameters.VMName.Value != c.tmpComputeName {
@ -180,6 +248,50 @@ func TestConfigShouldTransformToTemplateParameters(t *testing.T) {
}
}
func TestConfigShouldTransformToTemplateParametersLinux(t *testing.T) {
c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration())
c.OSType = constants.Target_Linux
templateParameters := c.toTemplateParameters()
if templateParameters.KeyVaultSecretValue != nil {
t.Errorf("Expected KeyVaultSecretValue to be empty for an os_type == '%s', but it was not.", c.OSType)
}
if templateParameters.ObjectId != nil {
t.Errorf("Expected ObjectId to be empty for an os_type == '%s', but it was not.", c.OSType)
}
if templateParameters.TenantId != nil {
t.Errorf("Expected TenantId to be empty for an os_type == '%s', but it was not.", c.OSType)
}
}
func TestConfigShouldTransformToTemplateParametersWindows(t *testing.T) {
c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration())
c.OSType = constants.Target_Windows
templateParameters := c.toTemplateParameters()
if templateParameters.SshAuthorizedKey != nil {
t.Errorf("Expected SshAuthorizedKey to be empty for an os_type == '%s', but it was not.", c.OSType)
}
if templateParameters.KeyVaultName == nil {
t.Errorf("Expected KeyVaultName to not be empty for an os_type == '%s', but it was not.", c.OSType)
}
if templateParameters.KeyVaultSecretValue == nil {
t.Errorf("Expected KeyVaultSecretValue to not be empty for an os_type == '%s', but it was not.", c.OSType)
}
if templateParameters.ObjectId == nil {
t.Errorf("Expected ObjectId to not be empty for an os_type == '%s', but it was not.", c.OSType)
}
if templateParameters.TenantId == nil {
t.Errorf("Expected TenantId to not be empty for an os_type == '%s', but it was not.", c.OSType)
}
}
func TestConfigShouldTransformToVirtualMachineCaptureParameters(t *testing.T) {
c, _, _ := newConfig(getArmBuilderConfiguration(), getPackerConfiguration())
parameters := c.toVirtualMachineCaptureParameters()
@ -216,30 +328,67 @@ func TestConfigShouldSupportPackersConfigElements(t *testing.T) {
}
}
func getArmBuilderConfiguration() interface{} {
func TestUserDeviceLoginIsEnabledForLinux(t *testing.T) {
config := map[string]string{
"capture_name_prefix": "ignore",
"capture_container_name": "ignore",
"image_offer": "ignore",
"image_publisher": "ignore",
"image_sku": "ignore",
"location": "ignore",
"storage_account": "ignore",
"subscription_id": "ignore",
"os_type": constants.Target_Linux,
}
_, _, err := newConfig(config, getPackerConfiguration())
if err != nil {
t.Fatalf("failed to use device login for Linux: %s", err)
}
}
func TestUseDeviceLoginIsDisabledForWindows(t *testing.T) {
config := map[string]string{
"capture_name_prefix": "ignore",
"capture_container_name": "ignore",
"image_offer": "ignore",
"image_publisher": "ignore",
"image_sku": "ignore",
"location": "ignore",
"storage_account": "ignore",
"subscription_id": "ignore",
"os_type": constants.Target_Windows,
}
_, _, err := newConfig(config, getPackerConfiguration())
if err == nil {
t.Fatalf("Expected test to fail, but it succeeded")
}
multiError, _ := err.(*packer.MultiError)
if len(multiError.Errors) != 3 {
t.Errorf("Expected to find 3 errors, but found %d errors", len(multiError.Errors))
}
if !strings.Contains(err.Error(), "client_id must be specified") {
t.Errorf("Expected to find error for 'client_id must be specified")
}
if !strings.Contains(err.Error(), "client_secret must be specified") {
t.Errorf("Expected to find error for 'client_secret must be specified")
}
if !strings.Contains(err.Error(), "tenant_id must be specified") {
t.Errorf("Expected to find error for 'tenant_id must be specified")
}
}
func getArmBuilderConfiguration() map[string]string {
m := make(map[string]string)
for _, v := range requiredConfigValues {
m[v] = fmt.Sprintf("%s00", v)
}
return getArmBuilderConfigurationFromMap(m)
}
func getArmBuilderConfigurationFromMap(kvp map[string]string) interface{} {
bs := bytes.NewBufferString("{")
for k, v := range kvp {
bs.WriteString(fmt.Sprintf("\"%s\": \"%s\",\n", k, v))
}
// remove the trailing ",\n" because JSON
bs.Truncate(bs.Len() - 2)
bs.WriteString("}")
var config interface{}
json.Unmarshal([]byte(bs.String()), &config)
return config
m["os_type"] = constants.Target_Linux
return m
}
func getPackerConfiguration() interface{} {
@ -260,14 +409,11 @@ func getPackerConfiguration() interface{} {
return config
}
func getPackerCommunicatorConfiguration() interface{} {
var doc = `{
"ssh_timeout": "1h",
"winrm_timeout": "2h"
}`
var config interface{}
json.Unmarshal([]byte(doc), &config)
func getPackerCommunicatorConfiguration() map[string]string {
config := map[string]string{
"ssh_timeout": "1h",
"winrm_timeout": "2h",
}
return config
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -79,13 +79,13 @@ func TestMalformedTemplatesShouldReturnError(t *testing.T) {
func getTemplateParameters() TemplateParameters {
templateParameters := TemplateParameters{
AdminUsername: &TemplateParameter{"adminusername00"},
DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"},
OSDiskName: &TemplateParameter{"osdiskname00"},
SshAuthorizedKey: &TemplateParameter{"sshkeydata00"},
StorageAccountName: &TemplateParameter{"storageaccountname00"},
VMName: &TemplateParameter{"vmname00"},
VMSize: &TemplateParameter{"vmsize00"},
AdminUsername: &TemplateParameter{"adminusername00"},
DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"},
OSDiskName: &TemplateParameter{"osdiskname00"},
SshAuthorizedKey: &TemplateParameter{"sshkeydata00"},
StorageAccountBlobEndpoint: &TemplateParameter{"storageaccountblobendpoint00"},
VMName: &TemplateParameter{"vmname00"},
VMSize: &TemplateParameter{"vmsize00"},
}
return templateParameters

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm

View File

@ -0,0 +1,70 @@
package arm
import (
"bytes"
"io/ioutil"
"log"
"net/http"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/mitchellh/packer/builder/azure/common/logutil"
"io"
)
func chop(data []byte, maxlen int64) string {
s := string(data)
if int64(len(s)) > maxlen {
s = s[:maxlen] + "..."
}
return s
}
func handleBody(body io.ReadCloser, maxlen int64) (io.ReadCloser, string) {
if body == nil {
return nil, ""
}
defer body.Close()
b, err := ioutil.ReadAll(body)
if err != nil {
return nil, ""
}
return ioutil.NopCloser(bytes.NewReader(b)), chop(b, maxlen)
}
func withInspection(maxlen int64) autorest.PrepareDecorator {
return func(p autorest.Preparer) autorest.Preparer {
return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) {
body, bodyString := handleBody(r.Body, maxlen)
r.Body = body
log.Print("Azure request", logutil.Fields{
"method": r.Method,
"request": r.URL.String(),
"body": bodyString,
})
return p.Prepare(r)
})
}
}
func byInspecting(maxlen int64) autorest.RespondDecorator {
return func(r autorest.Responder) autorest.Responder {
return autorest.ResponderFunc(func(resp *http.Response) error {
body, bodyString := handleBody(resp.Body, maxlen)
resp.Body = body
log.Print("Azure response", logutil.Fields{
"status": resp.Status,
"method": resp.Request.Method,
"request": resp.Request.URL.String(),
"x-ms-request-id": azure.ExtractRequestID(resp),
"body": bodyString,
})
return r.Respond(resp)
})
}
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -10,9 +10,8 @@ import (
"encoding/base64"
"encoding/pem"
"fmt"
"time"
"golang.org/x/crypto/ssh"
"time"
)
const (

View File

@ -1,14 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"testing"
"golang.org/x/crypto/ssh"
"testing"
)
func TestFart(t *testing.T) {
}
func TestAuthorizedKeyShouldParse(t *testing.T) {
testSubject, err := NewOpenSshKeyPairWithSize(512)
if err != nil {

33
builder/azure/arm/step.go Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func processInterruptibleResult(
result common.InterruptibleTaskResult, sayError func(error), state multistep.StateBag) multistep.StepAction {
if result.IsCancelled {
return multistep.ActionHalt
}
return processStepResult(result.Err, sayError, state)
}
func processStepResult(
err error, sayError func(error), state multistep.StateBag) multistep.StepAction {
if err != nil {
state.Put(constants.Error, err)
sayError(err)
return multistep.ActionHalt
}
return multistep.ActionContinue
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,14 +7,16 @@ import (
"fmt"
"github.com/Azure/azure-sdk-for-go/arm/compute"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepCaptureImage struct {
client *AzureClient
capture func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters) error
capture func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters, cancelCh <-chan struct{}) error
get func(client *AzureClient) *CaptureTemplate
say func(message string)
error func(e error)
}
@ -22,6 +24,7 @@ type StepCaptureImage struct {
func NewStepCaptureImage(client *AzureClient, ui packer.Ui) *StepCaptureImage {
var step = &StepCaptureImage{
client: client,
get: func(client *AzureClient) *CaptureTemplate { return client.Template },
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
}
@ -30,20 +33,17 @@ func NewStepCaptureImage(client *AzureClient, ui packer.Ui) *StepCaptureImage {
return step
}
func (s *StepCaptureImage) captureImage(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters) error {
generalizeResponse, err := s.client.Generalize(resourceGroupName, computeName)
func (s *StepCaptureImage) captureImage(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters, cancelCh <-chan struct{}) error {
_, err := s.client.Generalize(resourceGroupName, computeName)
if err != nil {
return err
}
s.client.VirtualMachinesClient.PollAsNeeded(generalizeResponse.Response)
captureResponse, err := s.client.Capture(resourceGroupName, computeName, *parameters)
_, err = s.client.Capture(resourceGroupName, computeName, *parameters, cancelCh)
if err != nil {
return err
}
s.client.VirtualMachinesClient.PollAsNeeded(captureResponse.Response.Response)
return nil
}
@ -57,15 +57,24 @@ func (s *StepCaptureImage) Run(state multistep.StateBag) multistep.StepAction {
s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName))
s.say(fmt.Sprintf(" -> ComputeName : '%s'", computeName))
err := s.capture(resourceGroupName, computeName, parameters)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
result := common.StartInterruptibleTask(
func() bool { return common.IsStateCancelled(state) },
func(cancelCh <-chan struct{}) error {
return s.capture(resourceGroupName, computeName, parameters, cancelCh)
})
return multistep.ActionHalt
}
// HACK(chrboum): I do not like this. The capture method should be returning this value
// instead having to pass in another lambda. I'm in this pickle because I am using
// common.StartInterruptibleTask which is not parametric, and only returns a type of error.
// I could change it to interface{}, but I do not like that solution either.
//
// Having to resort to capturing the template via an inspector is hack, and once I can
// resolve that I can cleanup this code too. See the comments in azure_client.go for more
// details.
template := s.get(s.client)
state.Put(constants.ArmCaptureTemplate, template)
return multistep.ActionContinue
return processInterruptibleResult(result, s.error, state)
}
func (*StepCaptureImage) Cleanup(multistep.StateBag) {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -8,16 +8,18 @@ import (
"testing"
"github.com/Azure/azure-sdk-for-go/arm/compute"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepCaptureImageShouldFailIfCaptureFails(t *testing.T) {
var testSubject = &StepCaptureImage{
capture: func(string, string, *compute.VirtualMachineCaptureParameters) error {
capture: func(string, string, *compute.VirtualMachineCaptureParameters, <-chan struct{}) error {
return fmt.Errorf("!! Unit Test FAIL !!")
},
get: func(client *AzureClient) *CaptureTemplate {
return nil
},
say: func(message string) {},
error: func(e error) {},
}
@ -36,9 +38,12 @@ func TestStepCaptureImageShouldFailIfCaptureFails(t *testing.T) {
func TestStepCaptureImageShouldPassIfCapturePasses(t *testing.T) {
var testSubject = &StepCaptureImage{
capture: func(string, string, *compute.VirtualMachineCaptureParameters) error { return nil },
say: func(message string) {},
error: func(e error) {},
capture: func(string, string, *compute.VirtualMachineCaptureParameters, <-chan struct{}) error { return nil },
get: func(client *AzureClient) *CaptureTemplate {
return nil
},
say: func(message string) {},
error: func(e error) {},
}
stateBag := createTestStateBagStepCaptureImage()
@ -54,18 +59,27 @@ func TestStepCaptureImageShouldPassIfCapturePasses(t *testing.T) {
}
func TestStepCaptureImageShouldTakeStepArgumentsFromStateBag(t *testing.T) {
cancelCh := make(chan<- struct{})
defer close(cancelCh)
var actualResourceGroupName string
var actualComputeName string
var actualVirtualMachineCaptureParameters *compute.VirtualMachineCaptureParameters
actualCaptureTemplate := &CaptureTemplate{
Schema: "!! Unit Test !!",
}
var testSubject = &StepCaptureImage{
capture: func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters) error {
capture: func(resourceGroupName string, computeName string, parameters *compute.VirtualMachineCaptureParameters, cancelCh <-chan struct{}) error {
actualResourceGroupName = resourceGroupName
actualComputeName = computeName
actualVirtualMachineCaptureParameters = parameters
return nil
},
get: func(client *AzureClient) *CaptureTemplate {
return actualCaptureTemplate
},
say: func(message string) {},
error: func(e error) {},
}
@ -80,6 +94,7 @@ func TestStepCaptureImageShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var expectedComputeName = stateBag.Get(constants.ArmComputeName).(string)
var expectedResourceGroupName = stateBag.Get(constants.ArmResourceGroupName).(string)
var expectedVirtualMachineCaptureParameters = stateBag.Get(constants.ArmVirtualMachineCaptureParameters).(*compute.VirtualMachineCaptureParameters)
var expectedCaptureTemplate = stateBag.Get(constants.ArmCaptureTemplate).(*CaptureTemplate)
if actualComputeName != expectedComputeName {
t.Fatalf("Expected StepCaptureImage to source 'constants.ArmComputeName' from the state bag, but it did not.")
@ -92,6 +107,10 @@ func TestStepCaptureImageShouldTakeStepArgumentsFromStateBag(t *testing.T) {
if actualVirtualMachineCaptureParameters != expectedVirtualMachineCaptureParameters {
t.Fatalf("Expected StepCaptureImage to source 'constants.ArmVirtualMachineCaptureParameters' from the state bag, but it did not.")
}
if actualCaptureTemplate != expectedCaptureTemplate {
t.Fatalf("Expected StepCaptureImage to source 'constants.ArmCaptureTemplate' from the state bag, but it did not.")
}
}
func createTestStateBagStepCaptureImage() multistep.StateBag {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,8 +7,8 @@ import (
"fmt"
"github.com/Azure/azure-sdk-for-go/arm/resources/resources"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
@ -48,15 +48,29 @@ func (s *StepCreateResourceGroup) Run(state multistep.StateBag) multistep.StepAc
s.say(fmt.Sprintf(" -> Location : '%s'", location))
err := s.create(resourceGroupName, location)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
return multistep.ActionHalt
if err == nil {
state.Put(constants.ArmIsResourceGroupCreated, true)
}
return multistep.ActionContinue
return processStepResult(err, s.error, state)
}
func (*StepCreateResourceGroup) Cleanup(multistep.StateBag) {
func (s *StepCreateResourceGroup) Cleanup(state multistep.StateBag) {
_, ok := state.GetOk(constants.ArmIsResourceGroupCreated)
if !ok {
return
}
ui := state.Get("ui").(packer.Ui)
ui.Say("\nCleanup requested, deleting resource group ...")
var resourceGroupName = state.Get(constants.ArmResourceGroupName).(string)
_, err := s.client.GroupsClient.Delete(resourceGroupName, nil)
if err != nil {
ui.Error(fmt.Sprintf("Error deleting resource group. Please delete it manually.\n\n"+
"Name: %s\n"+
"Error: %s", resourceGroupName, err))
}
ui.Say("Resource group has been deleted.")
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,8 +7,8 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepCreateResourceGroupShouldFailIfCreateFails(t *testing.T) {
@ -80,6 +80,11 @@ func TestStepCreateResourceGroupShouldTakeStepArgumentsFromStateBag(t *testing.T
if actualLocation != expectedLocation {
t.Fatalf("Expected the step to source 'constants.ArmResourceGroupName' from the state bag, but it did not.")
}
_, ok := stateBag.GetOk(constants.ArmIsResourceGroupCreated)
if !ok {
t.Fatalf("Expected the step to add item to stateBag['constants.ArmIsResourceGroupCreated'], but it did not.")
}
}
func createTestStateBagStepCreateResourceGroup() multistep.StateBag {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -33,7 +33,7 @@ func NewStepDeleteOSDisk(client *AzureClient, ui packer.Ui) *StepDeleteOSDisk {
}
func (s *StepDeleteOSDisk) deleteBlob(storageContainerName string, blobName string) error {
return s.client.BlobStorageClient.DeleteBlob(storageContainerName, blobName)
return s.client.BlobStorageClient.DeleteBlob(storageContainerName, blobName, nil)
}
func (s *StepDeleteOSDisk) Run(state multistep.StateBag) multistep.StepAction {
@ -54,13 +54,7 @@ func (s *StepDeleteOSDisk) Run(state multistep.StateBag) multistep.StepAction {
var blobName = strings.Join(xs[2:], "/")
err = s.delete(storageAccountName, blobName)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
return multistep.ActionHalt
}
return multistep.ActionContinue
return processStepResult(err, s.error, state)
}
func (*StepDeleteOSDisk) Cleanup(multistep.StateBag) {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,8 +7,8 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepDeleteOSDiskShouldFailIfGetFails(t *testing.T) {

View File

@ -1,19 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepDeleteResourceGroup struct {
client *AzureClient
delete func(resourceGroupName string) error
delete func(resourceGroupName string, cancelCh <-chan struct{}) error
say func(message string)
error func(e error)
}
@ -29,32 +30,23 @@ func NewStepDeleteResourceGroup(client *AzureClient, ui packer.Ui) *StepDeleteRe
return step
}
func (s *StepDeleteResourceGroup) deleteResourceGroup(resourceGroupName string) error {
res, err := s.client.GroupsClient.Delete(resourceGroupName)
if err != nil {
return err
}
func (s *StepDeleteResourceGroup) deleteResourceGroup(resourceGroupName string, cancelCh <-chan struct{}) error {
_, err := s.client.GroupsClient.Delete(resourceGroupName, cancelCh)
s.client.GroupsClient.PollAsNeeded(res.Response)
return nil
return err
}
func (s *StepDeleteResourceGroup) Run(state multistep.StateBag) multistep.StepAction {
s.say("Deleting resource group ...")
var resourceGroupName = state.Get(constants.ArmResourceGroupName).(string)
s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName))
err := s.delete(resourceGroupName)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
result := common.StartInterruptibleTask(
func() bool { return common.IsStateCancelled(state) },
func(cancelCh <-chan struct{}) error { return s.delete(resourceGroupName, cancelCh) })
return multistep.ActionHalt
}
return multistep.ActionContinue
return processInterruptibleResult(result, s.error, state)
}
func (*StepDeleteResourceGroup) Cleanup(multistep.StateBag) {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,13 +7,13 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepDeleteResourceGroupShouldFailIfDeleteFails(t *testing.T) {
var testSubject = &StepDeleteResourceGroup{
delete: func(string) error { return fmt.Errorf("!! Unit Test FAIL !!") },
delete: func(string, <-chan struct{}) error { return fmt.Errorf("!! Unit Test FAIL !!") },
say: func(message string) {},
error: func(e error) {},
}
@ -32,7 +32,7 @@ func TestStepDeleteResourceGroupShouldFailIfDeleteFails(t *testing.T) {
func TestStepDeleteResourceGroupShouldPassIfDeletePasses(t *testing.T) {
var testSubject = &StepDeleteResourceGroup{
delete: func(string) error { return nil },
delete: func(string, <-chan struct{}) error { return nil },
say: func(message string) {},
error: func(e error) {},
}
@ -53,7 +53,7 @@ func TestStepDeleteResourceGroupShouldTakeStepArgumentsFromStateBag(t *testing.T
var actualResourceGroupName string
var testSubject = &StepDeleteResourceGroup{
delete: func(resourceGroupName string) error {
delete: func(resourceGroupName string, cancelCh <-chan struct{}) error {
actualResourceGroupName = resourceGroupName
return nil
},

View File

@ -1,47 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepDeployTemplate struct {
client *AzureClient
deploy func(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error
say func(message string)
error func(e error)
client *AzureClient
template string
deploy func(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters, cancelCh <-chan struct{}) error
say func(message string)
error func(e error)
}
func NewStepDeployTemplate(client *AzureClient, ui packer.Ui) *StepDeployTemplate {
func NewStepDeployTemplate(client *AzureClient, ui packer.Ui, template string) *StepDeployTemplate {
var step = &StepDeployTemplate{
client: client,
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
client: client,
template: template,
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
}
step.deploy = step.deployTemplate
return step
}
func (s *StepDeployTemplate) deployTemplate(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error {
factory := newDeploymentFactory(Linux)
func (s *StepDeployTemplate) deployTemplate(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters, cancelCh <-chan struct{}) error {
factory := newDeploymentFactory(s.template)
deployment, err := factory.create(*templateParameters)
if err != nil {
return err
}
res, err := s.client.DeploymentsClient.CreateOrUpdate(resourceGroupName, deploymentName, *deployment)
_, err = s.client.DeploymentsClient.CreateOrUpdate(resourceGroupName, deploymentName, *deployment, cancelCh)
if err != nil {
return err
}
s.client.DeploymentsClient.PollAsNeeded(res.Response.Response)
poller := NewDeploymentPoller(func() (string, error) {
r, e := s.client.DeploymentsClient.Get(resourceGroupName, deploymentName)
if r.Properties != nil && r.Properties.ProvisioningState != nil {
@ -73,15 +75,14 @@ func (s *StepDeployTemplate) Run(state multistep.StateBag) multistep.StepAction
s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName))
s.say(fmt.Sprintf(" -> DeploymentName : '%s'", deploymentName))
err := s.deploy(resourceGroupName, deploymentName, templateParameters)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
result := common.StartInterruptibleTask(
func() bool { return common.IsStateCancelled(state) },
func(cancelCh <-chan struct{}) error {
return s.deploy(resourceGroupName, deploymentName, templateParameters, cancelCh)
},
)
return multistep.ActionHalt
}
return multistep.ActionContinue
return processInterruptibleResult(result, s.error, state)
}
func (*StepDeployTemplate) Cleanup(multistep.StateBag) {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,15 +7,18 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepDeployTemplateShouldFailIfDeployFails(t *testing.T) {
var testSubject = &StepDeployTemplate{
deploy: func(string, string, *TemplateParameters) error { return fmt.Errorf("!! Unit Test FAIL !!") },
say: func(message string) {},
error: func(e error) {},
template: Linux,
deploy: func(string, string, *TemplateParameters, <-chan struct{}) error {
return fmt.Errorf("!! Unit Test FAIL !!")
},
say: func(message string) {},
error: func(e error) {},
}
stateBag := createTestStateBagStepDeployTemplate()
@ -32,9 +35,10 @@ func TestStepDeployTemplateShouldFailIfDeployFails(t *testing.T) {
func TestStepDeployTemplateShouldPassIfDeployPasses(t *testing.T) {
var testSubject = &StepDeployTemplate{
deploy: func(string, string, *TemplateParameters) error { return nil },
say: func(message string) {},
error: func(e error) {},
template: Linux,
deploy: func(string, string, *TemplateParameters, <-chan struct{}) error { return nil },
say: func(message string) {},
error: func(e error) {},
}
stateBag := createTestStateBagStepDeployTemplate()
@ -55,7 +59,8 @@ func TestStepDeployTemplateShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var actualTemplateParameters *TemplateParameters
var testSubject = &StepDeployTemplate{
deploy: func(resourceGroupName string, deploymentName string, templateParameter *TemplateParameters) error {
template: Linux,
deploy: func(resourceGroupName string, deploymentName string, templateParameter *TemplateParameters, cancelCh <-chan struct{}) error {
actualResourceGroupName = resourceGroupName
actualDeploymentName = deploymentName
actualTemplateParameters = templateParameter

View File

@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"time"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepGetCertificate struct {
client *AzureClient
template string
get func(keyVaultName string, secretName string) (string, error)
say func(message string)
error func(e error)
pause func()
}
func NewStepGetCertificate(client *AzureClient, ui packer.Ui) *StepGetCertificate {
var step = &StepGetCertificate{
client: client,
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
pause: func() { time.Sleep(30 * time.Second) },
}
step.get = step.getCertificateUrl
return step
}
func (s *StepGetCertificate) getCertificateUrl(keyVaultName string, secretName string) (string, error) {
secret, err := s.client.GetSecret(keyVaultName, secretName)
if err != nil {
return "", err
}
return *secret.ID, err
}
func (s *StepGetCertificate) Run(state multistep.StateBag) multistep.StepAction {
s.say("Getting the certificate's URL ...")
var keyVaultName = state.Get(constants.ArmKeyVaultName).(string)
s.say(fmt.Sprintf(" -> Key Vault Name : '%s'", keyVaultName))
s.say(fmt.Sprintf(" -> Key Vault Secret Name : '%s'", DefaultSecretName))
var err error
var url string
for i := 0; i < 5; i++ {
url, err = s.get(keyVaultName, DefaultSecretName)
if err == nil {
break
}
s.say(fmt.Sprintf(" ...failed to get certificate URL, retry(%d)", i))
s.pause()
}
if err != nil {
state.Put(constants.Error, err)
s.error(err)
return multistep.ActionHalt
}
s.say(fmt.Sprintf(" -> Certificate URL : '%s'", url))
state.Put(constants.ArmCertificateUrl, url)
return multistep.ActionContinue
}
func (*StepGetCertificate) Cleanup(multistep.StateBag) {
}

View File

@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"testing"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepGetCertificateShouldFailIfGetFails(t *testing.T) {
var testSubject = &StepGetCertificate{
get: func(string, string) (string, error) { return "", fmt.Errorf("!! Unit Test FAIL !!") },
say: func(message string) {},
error: func(e error) {},
pause: func() {},
}
stateBag := createTestStateBagStepGetCertificate()
var result = testSubject.Run(stateBag)
if result != multistep.ActionHalt {
t.Fatalf("Expected the step to return 'ActionHalt', but got '%d'.", result)
}
if _, ok := stateBag.GetOk(constants.Error); ok == false {
t.Fatalf("Expected the step to set stateBag['%s'], but it was not.", constants.Error)
}
}
func TestStepGetCertificateShouldPassIfGetPasses(t *testing.T) {
var testSubject = &StepGetCertificate{
get: func(string, string) (string, error) { return "", nil },
say: func(message string) {},
error: func(e error) {},
pause: func() {},
}
stateBag := createTestStateBagStepGetCertificate()
var result = testSubject.Run(stateBag)
if result != multistep.ActionContinue {
t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result)
}
if _, ok := stateBag.GetOk(constants.Error); ok == true {
t.Fatalf("Expected the step to not set stateBag['%s'], but it was.", constants.Error)
}
}
func TestStepGetCertificateShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var actualKeyVaultName string
var actualSecretName string
var testSubject = &StepGetCertificate{
get: func(keyVaultName string, secretName string) (string, error) {
actualKeyVaultName = keyVaultName
actualSecretName = secretName
return "http://key.vault/1", nil
},
say: func(message string) {},
error: func(e error) {},
pause: func() {},
}
stateBag := createTestStateBagStepGetCertificate()
var result = testSubject.Run(stateBag)
if result != multistep.ActionContinue {
t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result)
}
var expectedKeyVaultName = stateBag.Get(constants.ArmKeyVaultName).(string)
if actualKeyVaultName != expectedKeyVaultName {
t.Fatalf("Expected StepGetCertificate to source 'constants.ArmKeyVaultName' from the state bag, but it did not.")
}
if actualSecretName != DefaultSecretName {
t.Fatalf("Expected StepGetCertificate to use default value for secret, but it did not.")
}
expectedCertificateUrl, ok := stateBag.GetOk(constants.ArmCertificateUrl)
if !ok {
t.Fatalf("Expected the state bag to have a value for '%s', but it did not.", constants.ArmCertificateUrl)
}
if expectedCertificateUrl != "http://key.vault/1" {
t.Fatalf("Expected the value of stateBag[%s] to be 'http://key.vault/1', but got '%s'.", constants.ArmCertificateUrl, expectedCertificateUrl)
}
}
func createTestStateBagStepGetCertificate() multistep.StateBag {
stateBag := new(multistep.BasicStateBag)
stateBag.Put(constants.ArmKeyVaultName, "Unit Test: KeyVaultName")
return stateBag
}

View File

@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
@ -55,7 +55,7 @@ func (s *StepGetIPAddress) Run(state multistep.StateBag) multistep.StepAction {
return multistep.ActionHalt
}
s.say(fmt.Sprintf(" -> SSHHost : '%s'", address))
s.say(fmt.Sprintf(" -> Public IP : '%s'", address))
state.Put(constants.SSHHost, address)
return multistep.ActionContinue

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,8 +7,8 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepGetIPAddressShouldFailIfGetFails(t *testing.T) {
@ -75,11 +75,11 @@ func TestStepGetIPAddressShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var expectedIPAddressName = stateBag.Get(constants.ArmPublicIPAddressName).(string)
if actualIPAddressName != expectedIPAddressName {
t.Fatalf("Expected StepValidateTemplate to source 'constants.ArmIPAddressName' from the state bag, but it did not.")
t.Fatalf("Expected StepGetIPAddress to source 'constants.ArmIPAddressName' from the state bag, but it did not.")
}
if actualResourceGroupName != expectedResourceGroupName {
t.Fatalf("Expected StepValidateTemplate to source 'constants.ArmResourceGroupName' from the state bag, but it did not.")
t.Fatalf("Expected StepGetIPAddress to source 'constants.ArmResourceGroupName' from the state bag, but it did not.")
}
expectedIPAddress, ok := stateBag.GetOk(constants.SSHHost)

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -95,7 +95,7 @@ func TestStepGetOSDiskShouldTakeValidateArgumentsFromStateBag(t *testing.T) {
}
if expectedOSDiskVhd != "test.vhd" {
t.Fatalf("Expected the value of stateBag[%s] to be '127.0.0.1', but got '%s'.", constants.ArmOSDiskVhd, expectedOSDiskVhd)
t.Fatalf("Expected the value of stateBag[%s] to be 'test.vhd', but got '%s'.", constants.ArmOSDiskVhd, expectedOSDiskVhd)
}
}

View File

@ -1,19 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepPowerOffCompute struct {
client *AzureClient
powerOff func(resourceGroupName string, computeName string) error
powerOff func(resourceGroupName string, computeName string, cancelCh <-chan struct{}) error
say func(message string)
error func(e error)
}
@ -29,13 +30,12 @@ func NewStepPowerOffCompute(client *AzureClient, ui packer.Ui) *StepPowerOffComp
return step
}
func (s *StepPowerOffCompute) powerOffCompute(resourceGroupName string, computeName string) error {
res, err := s.client.PowerOff(resourceGroupName, computeName)
func (s *StepPowerOffCompute) powerOffCompute(resourceGroupName string, computeName string, cancelCh <-chan struct{}) error {
_, err := s.client.PowerOff(resourceGroupName, computeName, cancelCh)
if err != nil {
return err
}
s.client.VirtualMachinesClient.PollAsNeeded(res.Response)
return nil
}
@ -48,15 +48,11 @@ func (s *StepPowerOffCompute) Run(state multistep.StateBag) multistep.StepAction
s.say(fmt.Sprintf(" -> ResourceGroupName : '%s'", resourceGroupName))
s.say(fmt.Sprintf(" -> ComputeName : '%s'", computeName))
err := s.powerOff(resourceGroupName, computeName)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
result := common.StartInterruptibleTask(
func() bool { return common.IsStateCancelled(state) },
func(cancelCh <-chan struct{}) error { return s.powerOff(resourceGroupName, computeName, cancelCh) })
return multistep.ActionHalt
}
return multistep.ActionContinue
return processInterruptibleResult(result, s.error, state)
}
func (*StepPowerOffCompute) Cleanup(multistep.StateBag) {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,13 +7,13 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepPowerOffComputeShouldFailIfPowerOffFails(t *testing.T) {
var testSubject = &StepPowerOffCompute{
powerOff: func(string, string) error { return fmt.Errorf("!! Unit Test FAIL !!") },
powerOff: func(string, string, <-chan struct{}) error { return fmt.Errorf("!! Unit Test FAIL !!") },
say: func(message string) {},
error: func(e error) {},
}
@ -32,7 +32,7 @@ func TestStepPowerOffComputeShouldFailIfPowerOffFails(t *testing.T) {
func TestStepPowerOffComputeShouldPassIfPowerOffPasses(t *testing.T) {
var testSubject = &StepPowerOffCompute{
powerOff: func(string, string) error { return nil },
powerOff: func(string, string, <-chan struct{}) error { return nil },
say: func(message string) {},
error: func(e error) {},
}
@ -54,7 +54,7 @@ func TestStepPowerOffComputeShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var actualComputeName string
var testSubject = &StepPowerOffCompute{
powerOff: func(resourceGroupName string, computeName string) error {
powerOff: func(resourceGroupName string, computeName string, cancelCh <-chan struct{}) error {
actualResourceGroupName = resourceGroupName
actualComputeName = computeName

View File

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepSetCertificate struct {
config *Config
say func(message string)
error func(e error)
}
func NewStepSetCertificate(config *Config, ui packer.Ui) *StepSetCertificate {
var step = &StepSetCertificate{
config: config,
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
}
return step
}
func (s *StepSetCertificate) Run(state multistep.StateBag) multistep.StepAction {
s.say("Setting the certificate's URL ...")
var winRMCertificateUrl = state.Get(constants.ArmCertificateUrl).(string)
s.config.tmpWinRMCertificateUrl = winRMCertificateUrl
state.Put(constants.ArmTemplateParameters, s.config.toTemplateParameters())
return multistep.ActionContinue
}
func (*StepSetCertificate) Cleanup(multistep.StateBag) {
}

View File

@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"testing"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepSetCertificateShouldPassIfGetPasses(t *testing.T) {
var testSubject = &StepSetCertificate{
config: new(Config),
say: func(message string) {},
error: func(e error) {},
}
stateBag := createTestStateBagStepSetCertificate()
var result = testSubject.Run(stateBag)
if result != multistep.ActionContinue {
t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result)
}
if _, ok := stateBag.GetOk(constants.Error); ok == true {
t.Fatalf("Expected the step to not set stateBag['%s'], but it was.", constants.Error)
}
}
func TestStepSetCertificateShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var testSubject = &StepSetCertificate{
config: new(Config),
say: func(message string) {},
error: func(e error) {},
}
stateBag := createTestStateBagStepSetCertificate()
var result = testSubject.Run(stateBag)
if result != multistep.ActionContinue {
t.Fatalf("Expected the step to return 'ActionContinue', but got '%d'.", result)
}
_, ok := stateBag.GetOk(constants.ArmTemplateParameters)
if !ok {
t.Fatalf("Expected the state bag to have a value for '%s', but it did not.", constants.ArmTemplateParameters)
}
}
func createTestStateBagStepSetCertificate() multistep.StateBag {
stateBag := new(multistep.BasicStateBag)
stateBag.Put(constants.ArmCertificateUrl, "Unit Test: Certificate URL")
return stateBag
}

View File

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"github.com/mitchellh/packer/builder/azure/common"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"testing"
)
func TestProcessStepResultShouldContinueForNonErrors(t *testing.T) {
stateBag := new(multistep.BasicStateBag)
code := processStepResult(nil, func(error) { t.Fatalf("Should not be called!") }, stateBag)
if _, ok := stateBag.GetOk(constants.Error); ok {
t.Errorf("Error was nil, but was still in the state bag.")
}
if code != multistep.ActionContinue {
t.Errorf("Expected ActionContinue(%d), but got=%d", multistep.ActionContinue, code)
}
}
func TestProcessStepResultShouldHaltOnError(t *testing.T) {
stateBag := new(multistep.BasicStateBag)
isSaidError := false
code := processStepResult(fmt.Errorf("boom"), func(error) { isSaidError = true }, stateBag)
if _, ok := stateBag.GetOk(constants.Error); !ok {
t.Errorf("Error was non nil, but was not in the state bag.")
}
if !isSaidError {
t.Errorf("Expected error to be said, but it was not.")
}
if code != multistep.ActionHalt {
t.Errorf("Expected ActionHalt(%d), but got=%d", multistep.ActionHalt, code)
}
}
func TestProcessStepResultShouldContinueOnSuccessfulTask(t *testing.T) {
stateBag := new(multistep.BasicStateBag)
result := common.InterruptibleTaskResult{
IsCancelled: false,
Err: nil,
}
code := processInterruptibleResult(result, func(error) { t.Fatalf("Should not be called!") }, stateBag)
if _, ok := stateBag.GetOk(constants.Error); ok {
t.Errorf("Error was nil, but was still in the state bag.")
}
if code != multistep.ActionContinue {
t.Errorf("Expected ActionContinue(%d), but got=%d", multistep.ActionContinue, code)
}
}
func TestProcessStepResultShouldHaltWhenTaskIsCancelled(t *testing.T) {
stateBag := new(multistep.BasicStateBag)
result := common.InterruptibleTaskResult{
IsCancelled: true,
Err: nil,
}
code := processInterruptibleResult(result, func(error) { t.Fatalf("Should not be called!") }, stateBag)
if _, ok := stateBag.GetOk(constants.Error); ok {
t.Errorf("Error was nil, but was still in the state bag.")
}
if code != multistep.ActionHalt {
t.Errorf("Expected ActionHalt(%d), but got=%d", multistep.ActionHalt, code)
}
}
func TestProcessStepResultShouldHaltOnTaskError(t *testing.T) {
stateBag := new(multistep.BasicStateBag)
isSaidError := false
result := common.InterruptibleTaskResult{
IsCancelled: false,
Err: fmt.Errorf("boom"),
}
code := processInterruptibleResult(result, func(error) { isSaidError = true }, stateBag)
if _, ok := stateBag.GetOk(constants.Error); !ok {
t.Errorf("Error was *not* nil, but was not in the state bag.")
}
if !isSaidError {
t.Errorf("Expected error to be said, but it was not.")
}
if code != multistep.ActionHalt {
t.Errorf("Expected ActionHalt(%d), but got=%d", multistep.ActionHalt, code)
}
}

View File

@ -1,28 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
import (
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type StepValidateTemplate struct {
client *AzureClient
template string
validate func(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error
say func(message string)
error func(e error)
}
func NewStepValidateTemplate(client *AzureClient, ui packer.Ui) *StepValidateTemplate {
func NewStepValidateTemplate(client *AzureClient, ui packer.Ui, template string) *StepValidateTemplate {
var step = &StepValidateTemplate{
client: client,
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
client: client,
template: template,
say: func(message string) { ui.Say(message) },
error: func(e error) { ui.Error(e.Error()) },
}
step.validate = step.validateTemplate
@ -30,7 +32,7 @@ func NewStepValidateTemplate(client *AzureClient, ui packer.Ui) *StepValidateTem
}
func (s *StepValidateTemplate) validateTemplate(resourceGroupName string, deploymentName string, templateParameters *TemplateParameters) error {
factory := newDeploymentFactory(Linux)
factory := newDeploymentFactory(s.template)
deployment, err := factory.create(*templateParameters)
if err != nil {
@ -52,14 +54,7 @@ func (s *StepValidateTemplate) Run(state multistep.StateBag) multistep.StepActio
s.say(fmt.Sprintf(" -> DeploymentName : '%s'", deploymentName))
err := s.validate(resourceGroupName, deploymentName, templateParameters)
if err != nil {
state.Put(constants.Error, err)
s.error(err)
return multistep.ActionHalt
}
return multistep.ActionContinue
return processStepResult(err, s.error, state)
}
func (*StepValidateTemplate) Cleanup(multistep.StateBag) {

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -7,13 +7,14 @@ import (
"fmt"
"testing"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
)
func TestStepValidateTemplateShouldFailIfValidateFails(t *testing.T) {
var testSubject = &StepValidateTemplate{
template: Linux,
validate: func(string, string, *TemplateParameters) error { return fmt.Errorf("!! Unit Test FAIL !!") },
say: func(message string) {},
error: func(e error) {},
@ -33,6 +34,7 @@ func TestStepValidateTemplateShouldFailIfValidateFails(t *testing.T) {
func TestStepValidateTemplateShouldPassIfValidatePasses(t *testing.T) {
var testSubject = &StepValidateTemplate{
template: Linux,
validate: func(string, string, *TemplateParameters) error { return nil },
say: func(message string) {},
error: func(e error) {},
@ -56,6 +58,7 @@ func TestStepValidateTemplateShouldTakeStepArgumentsFromStateBag(t *testing.T) {
var actualTemplateParameters *TemplateParameters
var testSubject = &StepValidateTemplate{
template: Linux,
validate: func(resourceGroupName string, deploymentName string, templateParameter *TemplateParameters) error {
actualResourceGroupName = resourceGroupName
actualDeploymentName = deploymentName

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -27,13 +27,16 @@ const Linux = `{
"imageSku": {
"type": "string"
},
"imageVersion": {
"type": "string"
},
"osDiskName": {
"type": "string"
},
"sshAuthorizedKey": {
"type": "string"
},
"storageAccountName": {
"storageAccountBlobEndpoint": {
"type": "string"
},
"vmSize": {
@ -151,12 +154,12 @@ const Linux = `{
"publisher": "[parameters('imagePublisher')]",
"offer": "[parameters('imageOffer')]",
"sku": "[parameters('imageSku')]",
"version": "latest"
"version": "[parameters('imageVersion')]"
},
"osDisk": {
"name": "osdisk",
"vhd": {
"uri": "[concat('http://',parameters('storageAccountName'),'.blob.core.windows.net/',variables('vmStorageAccountContainerName'),'/', parameters('osDiskName'),'.vhd')]"
"uri": "[concat(parameters('storageAccountBlobEndpoint'),variables('vmStorageAccountContainerName'),'/', parameters('osDiskName'),'.vhd')]"
},
"caching": "ReadWrite",
"createOption": "FromImage"
@ -178,3 +181,309 @@ const Linux = `{
}
]
}`
// Template to deploy a KeyVault.
//
// NOTE: the parameters for the KeyVault template are identical to Windows
// template. Keeping these values in sync simplifies the code at the expense
// of template bloat. This bloat may be addressed in the future.
const KeyVault = `{
"$schema": "http://schema.management.azure.com/schemas/2014-04-01-preview/deploymentTemplate.json",
"contentVersion": "1.0.0.0",
"parameters": {
"adminUserName": {
"type": "string"
},
"adminPassword": {
"type": "securestring"
},
"dnsNameForPublicIP": {
"type": "string"
},
"imageOffer": {
"type": "string"
},
"imagePublisher": {
"type": "string"
},
"imageSku": {
"type": "string"
},
"imageVersion": {
"type": "string"
},
"keyVaultName": {
"type": "string"
},
"keyVaultSecretValue": {
"type": "securestring"
},
"objectId": {
"type": "string"
},
"osDiskName": {
"type": "string"
},
"storageAccountBlobEndpoint": {
"type": "string"
},
"tenantId": {
"type": "string"
},
"vmName": {
"type": "string"
},
"vmSize": {
"type": "string"
},
"winRMCertificateUrl": {
"type": "string"
}
},
"variables": {
"apiVersion": "2015-06-01",
"location": "[resourceGroup().location]",
"keyVaultSecretName": "packerKeyVaultSecret"
},
"resources": [
{
"apiVersion": "[variables('apiVersion')]",
"type": "Microsoft.KeyVault/vaults",
"name": "[parameters('keyVaultName')]",
"location": "[variables('location')]",
"properties": {
"enabledForDeployment": "true",
"enabledForTemplateDeployment": "true",
"tenantId": "[parameters('tenantId')]",
"accessPolicies": [
{
"tenantId": "[parameters('tenantId')]",
"objectId": "[parameters('objectId')]",
"permissions": {
"keys": [ "all" ],
"secrets": [ "all" ]
}
}
],
"sku": {
"name": "standard",
"family": "A"
}
},
"resources": [
{
"apiVersion": "[variables('apiVersion')]",
"type": "secrets",
"name": "[variables('keyVaultSecretName')]",
"dependsOn": [
"[concat('Microsoft.KeyVault/vaults/', parameters('keyVaultName'))]"
],
"properties": {
"value": "[parameters('keyVaultSecretValue')]"
}
}
]
}
]
}`
const Windows = `{
"$schema": "http://schema.management.azure.com/schemas/2014-04-01-preview/deploymentTemplate.json",
"contentVersion": "1.0.0.0",
"parameters": {
"adminUserName": {
"type": "string"
},
"adminPassword": {
"type": "securestring"
},
"dnsNameForPublicIP": {
"type": "string"
},
"imageOffer": {
"type": "string"
},
"imagePublisher": {
"type": "string"
},
"imageSku": {
"type": "string"
},
"imageVersion": {
"type": "string"
},
"keyVaultName": {
"type": "string"
},
"keyVaultSecretValue": {
"type": "securestring"
},
"objectId": {
"type": "string"
},
"osDiskName": {
"type": "string"
},
"storageAccountBlobEndpoint": {
"type": "string"
},
"tenantId": {
"type": "string"
},
"vmName": {
"type": "string"
},
"vmSize": {
"type": "string"
},
"winRMCertificateUrl": {
"type": "string"
}
},
"variables": {
"addressPrefix": "10.0.0.0/16",
"apiVersion": "2015-06-15",
"location": "[resourceGroup().location]",
"nicName": "packerNic",
"publicIPAddressName": "packerPublicIP",
"publicIPAddressType": "Dynamic",
"subnetName": "packerSubnet",
"subnetAddressPrefix": "10.0.0.0/24",
"subnetRef": "[concat(variables('vnetID'),'/subnets/',variables('subnetName'))]",
"virtualNetworkName": "packerNetwork",
"vmStorageAccountContainerName": "images",
"vnetID": "[resourceId('Microsoft.Network/virtualNetworks', variables('virtualNetworkName'))]"
},
"resources": [
{
"apiVersion": "[variables('apiVersion')]",
"type": "Microsoft.Network/publicIPAddresses",
"name": "[variables('publicIPAddressName')]",
"location": "[variables('location')]",
"properties": {
"publicIPAllocationMethod": "[variables('publicIPAddressType')]",
"dnsSettings": {
"domainNameLabel": "[parameters('dnsNameForPublicIP')]"
}
}
},
{
"apiVersion": "[variables('apiVersion')]",
"type": "Microsoft.Network/virtualNetworks",
"name": "[variables('virtualNetworkName')]",
"location": "[variables('location')]",
"properties": {
"addressSpace": {
"addressPrefixes": [
"[variables('addressPrefix')]"
]
},
"subnets": [
{
"name": "[variables('subnetName')]",
"properties": {
"addressPrefix": "[variables('subnetAddressPrefix')]"
}
}
]
}
},
{
"apiVersion": "[variables('apiVersion')]",
"type": "Microsoft.Network/networkInterfaces",
"name": "[variables('nicName')]",
"location": "[variables('location')]",
"dependsOn": [
"[concat('Microsoft.Network/publicIPAddresses/', variables('publicIPAddressName'))]",
"[concat('Microsoft.Network/virtualNetworks/', variables('virtualNetworkName'))]"
],
"properties": {
"ipConfigurations": [
{
"name": "ipconfig",
"properties": {
"privateIPAllocationMethod": "Dynamic",
"publicIPAddress": {
"id": "[resourceId('Microsoft.Network/publicIPAddresses', variables('publicIPAddressName'))]"
},
"subnet": {
"id": "[variables('subnetRef')]"
}
}
}
]
}
},
{
"apiVersion": "[variables('apiVersion')]",
"type": "Microsoft.Compute/virtualMachines",
"name": "[parameters('vmName')]",
"location": "[variables('location')]",
"dependsOn": [
"[concat('Microsoft.Network/networkInterfaces/', variables('nicName'))]"
],
"properties": {
"hardwareProfile": {
"vmSize": "[parameters('vmSize')]"
},
"osProfile": {
"computerName": "[parameters('vmName')]",
"adminUsername": "[parameters('adminUsername')]",
"adminPassword": "[parameters('adminPassword')]",
"secrets": [
{
"sourceVault": {
"id": "[resourceId(resourceGroup().name, 'Microsoft.KeyVault/vaults', parameters('keyVaultName'))]"
},
"vaultCertificates": [
{
"certificateUrl": "[parameters('winRMCertificateUrl')]",
"certificateStore": "My"
}
]
}
],
"windowsConfiguration": {
"provisionVMAgent": "true",
"winRM": {
"listeners": [{
"protocol": "https",
"certificateUrl": "[parameters('winRMCertificateUrl')]"
}
]
},
"enableAutomaticUpdates": "true"
}
},
"storageProfile": {
"imageReference": {
"publisher": "[parameters('imagePublisher')]",
"offer": "[parameters('imageOffer')]",
"sku": "[parameters('imageSku')]",
"version": "[parameters('imageVersion')]"
},
"osDisk": {
"name": "osdisk",
"vhd": {
"uri": "[concat(parameters('storageAccountBlobEndpoint'),variables('vmStorageAccountContainerName'),'/', parameters('osDiskName'),'.vhd')]"
},
"caching": "ReadWrite",
"createOption": "FromImage"
}
},
"networkProfile": {
"networkInterfaces": [
{
"id": "[resourceId('Microsoft.Network/networkInterfaces', variables('nicName'))]"
}
]
},
"diagnosticsProfile": {
"bootDiagnostics": {
"enabled": "false"
}
}
}
}
]
}`

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -21,15 +21,21 @@ type TemplateParameter struct {
}
type TemplateParameters struct {
AdminUsername *TemplateParameter `json:"adminUsername,omitempty"`
AdminPassword *TemplateParameter `json:"adminPassword,omitempty"`
DnsNameForPublicIP *TemplateParameter `json:"dnsNameForPublicIP,omitempty"`
ImageOffer *TemplateParameter `json:"imageOffer,omitempty"`
ImagePublisher *TemplateParameter `json:"imagePublisher,omitempty"`
ImageSku *TemplateParameter `json:"imageSku,omitempty"`
OSDiskName *TemplateParameter `json:"osDiskName,omitempty"`
SshAuthorizedKey *TemplateParameter `json:"sshAuthorizedKey,omitempty"`
StorageAccountName *TemplateParameter `json:"storageAccountName,omitempty"`
VMSize *TemplateParameter `json:"vmSize,omitempty"`
VMName *TemplateParameter `json:"vmName,omitempty"`
AdminUsername *TemplateParameter `json:"adminUsername,omitempty"`
AdminPassword *TemplateParameter `json:"adminPassword,omitempty"`
DnsNameForPublicIP *TemplateParameter `json:"dnsNameForPublicIP,omitempty"`
ImageOffer *TemplateParameter `json:"imageOffer,omitempty"`
ImagePublisher *TemplateParameter `json:"imagePublisher,omitempty"`
ImageSku *TemplateParameter `json:"imageSku,omitempty"`
ImageVersion *TemplateParameter `json:"imageVersion,omitempty"`
KeyVaultName *TemplateParameter `json:"keyVaultName,omitempty"`
KeyVaultSecretValue *TemplateParameter `json:"keyVaultSecretValue,omitempty"`
ObjectId *TemplateParameter `json:"objectId,omitempty"`
OSDiskName *TemplateParameter `json:"osDiskName,omitempty"`
SshAuthorizedKey *TemplateParameter `json:"sshAuthorizedKey,omitempty"`
StorageAccountBlobEndpoint *TemplateParameter `json:"storageAccountBlobEndpoint,omitempty"`
TenantId *TemplateParameter `json:"tenantId,omitempty"`
VMSize *TemplateParameter `json:"vmSize,omitempty"`
VMName *TemplateParameter `json:"vmName,omitempty"`
WinRMCertificateUrl *TemplateParameter `json:"winRMCertificateUrl,omitempty"`
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -12,17 +12,17 @@ import (
func TestTemplateParametersShouldHaveExpectedKeys(t *testing.T) {
params := TemplateParameters{
AdminUsername: &TemplateParameter{"sentinel"},
AdminPassword: &TemplateParameter{"sentinel"},
DnsNameForPublicIP: &TemplateParameter{"sentinel"},
ImageOffer: &TemplateParameter{"sentinel"},
ImagePublisher: &TemplateParameter{"sentinel"},
ImageSku: &TemplateParameter{"sentinel"},
OSDiskName: &TemplateParameter{"sentinel"},
SshAuthorizedKey: &TemplateParameter{"sentinel"},
StorageAccountName: &TemplateParameter{"sentinel"},
VMName: &TemplateParameter{"sentinel"},
VMSize: &TemplateParameter{"sentinel"},
AdminUsername: &TemplateParameter{"sentinel"},
AdminPassword: &TemplateParameter{"sentinel"},
DnsNameForPublicIP: &TemplateParameter{"sentinel"},
ImageOffer: &TemplateParameter{"sentinel"},
ImagePublisher: &TemplateParameter{"sentinel"},
ImageSku: &TemplateParameter{"sentinel"},
OSDiskName: &TemplateParameter{"sentinel"},
SshAuthorizedKey: &TemplateParameter{"sentinel"},
StorageAccountBlobEndpoint: &TemplateParameter{"sentinel"},
VMName: &TemplateParameter{"sentinel"},
VMSize: &TemplateParameter{"sentinel"},
}
bs, err := json.Marshal(params)
@ -46,7 +46,7 @@ func TestTemplateParametersShouldHaveExpectedKeys(t *testing.T) {
"imageSku",
"osDiskName",
"sshAuthorizedKey",
"storageAccountName",
"storageAccountBlobEndpoint",
"vmSize",
"vmName",
}
@ -61,17 +61,17 @@ func TestTemplateParametersShouldHaveExpectedKeys(t *testing.T) {
func TestParameterValuesShouldBeSet(t *testing.T) {
params := TemplateParameters{
AdminUsername: &TemplateParameter{"adminusername00"},
AdminPassword: &TemplateParameter{"adminpassword00"},
DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"},
ImageOffer: &TemplateParameter{"imageoffer00"},
ImagePublisher: &TemplateParameter{"imagepublisher00"},
ImageSku: &TemplateParameter{"imagesku00"},
OSDiskName: &TemplateParameter{"osdiskname00"},
SshAuthorizedKey: &TemplateParameter{"sshauthorizedkey00"},
StorageAccountName: &TemplateParameter{"storageaccountname00"},
VMName: &TemplateParameter{"vmname00"},
VMSize: &TemplateParameter{"vmsize00"},
AdminUsername: &TemplateParameter{"adminusername00"},
AdminPassword: &TemplateParameter{"adminpassword00"},
DnsNameForPublicIP: &TemplateParameter{"dnsnameforpublicip00"},
ImageOffer: &TemplateParameter{"imageoffer00"},
ImagePublisher: &TemplateParameter{"imagepublisher00"},
ImageSku: &TemplateParameter{"imagesku00"},
OSDiskName: &TemplateParameter{"osdiskname00"},
SshAuthorizedKey: &TemplateParameter{"sshauthorizedkey00"},
StorageAccountBlobEndpoint: &TemplateParameter{"storageaccountblobendpoint00"},
VMName: &TemplateParameter{"vmname00"},
VMSize: &TemplateParameter{"vmsize00"},
}
bs, err := json.Marshal(params)

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm
@ -15,11 +15,13 @@ const (
)
type TempName struct {
AdminPassword string
ComputeName string
DeploymentName string
ResourceGroupName string
OSDiskName string
AdminPassword string
CertificatePassword string
ComputeName string
DeploymentName string
KeyVaultName string
ResourceGroupName string
OSDiskName string
}
func NewTempName() *TempName {
@ -28,10 +30,12 @@ func NewTempName() *TempName {
suffix := common.RandomString(TempNameAlphabet, 10)
tempName.ComputeName = fmt.Sprintf("pkrvm%s", suffix)
tempName.DeploymentName = fmt.Sprintf("pkrdp%s", suffix)
tempName.KeyVaultName = fmt.Sprintf("pkrkv%s", suffix)
tempName.OSDiskName = fmt.Sprintf("pkros%s", suffix)
tempName.ResourceGroupName = fmt.Sprintf("packer-Resource-Group-%s", suffix)
tempName.AdminPassword = common.RandomString(TempPasswordAlphabet, 32)
tempName.CertificatePassword = common.RandomString(TempPasswordAlphabet, 32)
return tempName
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package arm

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package constants

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package constants
@ -29,12 +29,18 @@ const (
Ui string = "ui"
)
const (
ArmBlobEndpoint string = "arm.BlobEndpoint"
ArmCaptureTemplate string = "arm.CaptureTemplate"
ArmComputeName string = "arm.ComputeName"
ArmCertificateUrl string = "arm.CertificateUrl"
ArmDeploymentName string = "arm.DeploymentName"
ArmKeyVaultName string = "arm.KeyVaultName"
ArmLocation string = "arm.Location"
ArmOSDiskVhd string = "arm.OSDiskVhd"
ArmPublicIPAddressName string = "arm.PublicIPAddressName"
ArmResourceGroupName string = "arm.ResourceGroupName"
ArmIsResourceGroupCreated string = "arm.IsResourceGroupCreated"
ArmStorageAccountName string = "arm.StorageAccountName"
ArmTemplateParameters string = "arm.TemplateParameters"
ArmVirtualMachineCaptureParameters string = "arm.VirtualMachineCaptureParameters"
)

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package constants

View File

@ -0,0 +1,247 @@
package common
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"github.com/Azure/azure-sdk-for-go/arm/resources/subscriptions"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/to"
"github.com/mitchellh/packer/version"
"github.com/mitchellh/go-homedir"
)
var (
// AD app id for packer-azure driver.
clientIDs = map[string]string{
azure.PublicCloud.Name: "04cc58ec-51ab-4833-ac0d-ce3a7912414b",
}
userAgent = fmt.Sprintf("packer/%s", version.FormattedVersion())
)
// NOTE(ahmetalpbalkan): Azure Active Directory implements OAuth 2.0 Device Flow
// described here: https://tools.ietf.org/html/draft-denniss-oauth-device-flow-00
// Although it has some gotchas, most of the authentication logic is in Azure SDK
// for Go helper packages.
//
// Device auth prints a message to the screen telling the user to click on URL
// and approve the app on the browser, meanwhile the client polls the auth API
// for a token. Once we have token, we save it locally to a file with proper
// permissions and when the token expires (in Azure case typically 1 hour) SDK
// will automatically refresh the specified token and will call the refresh
// callback function we implement here. This way we will always be storing a
// token with a refresh_token saved on the machine.
// Authenticate fetches a token from the local file cache or initiates a consent
// flow and waits for token to be obtained.
func Authenticate(env azure.Environment, subscriptionID string, say func(string)) (*azure.ServicePrincipalToken, error) {
clientID, ok := clientIDs[env.Name]
if !ok {
return nil, fmt.Errorf("packer-azure application not set up for Azure environment %q", env.Name)
}
// First we locate the tenant ID of the subscription as we store tokens per
// tenant (which could have multiple subscriptions)
say(fmt.Sprintf("Looking up AAD Tenant ID: subscriptionID=%s.", subscriptionID))
tenantID, err := findTenantID(env, subscriptionID)
if err != nil {
return nil, err
}
say(fmt.Sprintf("Found AAD Tenant ID: tenantID=%s", tenantID))
oauthCfg, err := env.OAuthConfigForTenant(tenantID)
if err != nil {
return nil, fmt.Errorf("Failed to obtain oauth config for azure environment: %v", err)
}
// for AzurePublicCloud (https://management.core.windows.net/), this old
// Service Management scope covers both ASM and ARM.
apiScope := env.ServiceManagementEndpoint
tokenPath := tokenCachePath(tenantID)
saveToken := mkTokenCallback(tokenPath)
saveTokenCallback := func(t azure.Token) error {
say("Azure token expired. Saving the refreshed token...")
return saveToken(t)
}
// Lookup the token cache file for an existing token.
spt, err := tokenFromFile(say, *oauthCfg, tokenPath, clientID, apiScope, saveTokenCallback)
if err != nil {
return nil, err
}
if spt != nil {
say(fmt.Sprintf("Auth token found in file: %s", tokenPath))
// NOTE(ahmetalpbalkan): The token file we found may contain an
// expired access_token. In that case, the first call to Azure SDK will
// attempt to refresh the token using refresh_token, which might have
// expired[1], in that case we will get an error and we shall remove the
// token file and initiate token flow again so that the user would not
// need removing the token cache file manually.
//
// [1]: expiration date of refresh_token is not returned in AAD /token
// response, we just know it is 14 days. Therefore users token
// will go stale every 14 days and we will delete the token file,
// re-initiate the device flow.
say("Validating the token.")
if err := validateToken(env, spt); err != nil {
say(fmt.Sprintf("Error: %v", err))
say("Stored Azure credentials expired. Please reauthenticate.")
say(fmt.Sprintf("Deleting %s", tokenPath))
if err := os.RemoveAll(tokenPath); err != nil {
return nil, fmt.Errorf("Error deleting stale token file: %v", err)
}
} else {
say("Token works.")
return spt, nil
}
}
// Start an OAuth 2.0 device flow
say(fmt.Sprintf("Initiating device flow: %s", tokenPath))
spt, err = tokenFromDeviceFlow(say, *oauthCfg, tokenPath, clientID, apiScope)
if err != nil {
return nil, err
}
say("Obtained service principal token.")
if err := saveToken(spt.Token); err != nil {
say("Error occurred saving token to cache file.")
return nil, err
}
return spt, nil
}
// tokenFromFile returns a token from the specified file if it is found, otherwise
// returns nil. Any error retrieving or creating the token is returned as an error.
func tokenFromFile(say func(string), oauthCfg azure.OAuthConfig, tokenPath, clientID, resource string,
callback azure.TokenRefreshCallback) (*azure.ServicePrincipalToken, error) {
say(fmt.Sprintf("Loading auth token from file: %s", tokenPath))
if _, err := os.Stat(tokenPath); err != nil {
if os.IsNotExist(err) { // file not found
return nil, nil
}
return nil, err
}
token, err := azure.LoadToken(tokenPath)
if err != nil {
return nil, fmt.Errorf("Failed to load token from file: %v", err)
}
spt, err := azure.NewServicePrincipalTokenFromManualToken(oauthCfg, clientID, resource, *token, callback)
if err != nil {
return nil, fmt.Errorf("Error constructing service principal token: %v", err)
}
return spt, nil
}
// tokenFromDeviceFlow prints a message to the screen for user to take action to
// consent application on a browser and in the meanwhile the authentication
// endpoint is polled until user gives consent, denies or the flow times out.
// Returned token must be saved.
func tokenFromDeviceFlow(say func(string), oauthCfg azure.OAuthConfig, tokenPath, clientID, resource string) (*azure.ServicePrincipalToken, error) {
cl := autorest.NewClientWithUserAgent(userAgent)
deviceCode, err := azure.InitiateDeviceAuth(&cl, oauthCfg, clientID, resource)
if err != nil {
return nil, fmt.Errorf("Failed to start device auth: %v", err)
}
// Example message: “To sign in, open https://aka.ms/devicelogin and enter
// the code 0000000 to authenticate.”
say(fmt.Sprintf("Microsoft Azure: %s", to.String(deviceCode.Message)))
token, err := azure.WaitForUserCompletion(&cl, deviceCode)
if err != nil {
return nil, fmt.Errorf("Failed to complete device auth: %v", err)
}
spt, err := azure.NewServicePrincipalTokenFromManualToken(oauthCfg, clientID, resource, *token)
if err != nil {
return nil, fmt.Errorf("Error constructing service principal token: %v", err)
}
return spt, nil
}
// tokenCachePath returns the full path the OAuth 2.0 token should be saved at
// for given tenant ID.
func tokenCachePath(tenantID string) string {
dir, err := homedir.Dir()
if err != nil {
dir, _ = filepath.Abs(os.Args[0])
}
return filepath.Join(dir, ".azure", "packer", fmt.Sprintf("oauth-%s.json", tenantID))
}
// mkTokenCallback returns a callback function that can be used to save the
// token initially or register to the Azure SDK to be called when the token is
// refreshed.
func mkTokenCallback(path string) azure.TokenRefreshCallback {
return func(t azure.Token) error {
if err := azure.SaveToken(path, 0600, t); err != nil {
return err
}
return nil
}
}
// validateToken makes a call to Azure SDK with given token, essentially making
// sure if the access_token valid, if not it uses SDKs functionality to
// automatically refresh the token using refresh_token (which might have
// expired). This check is essentially to make sure refresh_token is good.
func validateToken(env azure.Environment, token *azure.ServicePrincipalToken) error {
c := subscriptionsClient(env.ResourceManagerEndpoint)
// WTF(chrboum)
//c.Authorizer = token
_, err := c.List()
if err != nil {
return fmt.Errorf("Token validity check failed: %v", err)
}
return nil
}
// findTenantID figures out the AAD tenant ID of the subscription by making an
// unauthenticated request to the Get Subscription Details endpoint and parses
// the value from WWW-Authenticate header.
func findTenantID(env azure.Environment, subscriptionID string) (string, error) {
const hdrKey = "WWW-Authenticate"
c := subscriptionsClient(env.ResourceManagerEndpoint)
// we expect this request to fail (err != nil), but we are only interested
// in headers, so surface the error if the Response is not present (i.e.
// network error etc)
subs, err := c.Get(subscriptionID)
if subs.Response.Response == nil {
return "", fmt.Errorf("Request failed: %v", err)
}
// Expecting 401 StatusUnauthorized here, just read the header
if subs.StatusCode != http.StatusUnauthorized {
return "", fmt.Errorf("Unexpected response from Get Subscription: %v", err)
}
hdr := subs.Header.Get(hdrKey)
if hdr == "" {
return "", fmt.Errorf("Header %v not found in Get Subscription response", hdrKey)
}
// Example value for hdr:
// Bearer authorization_uri="https://login.windows.net/996fe9d1-6171-40aa-945b-4c64b63bf655", error="invalid_token", error_description="The authentication failed because of missing 'Authorization' header."
r := regexp.MustCompile(`authorization_uri=".*/([0-9a-f\-]+)"`)
m := r.FindStringSubmatch(hdr)
if m == nil {
return "", fmt.Errorf("Could not find the tenant ID in header: %s %q", hdrKey, hdr)
}
return m[1], nil
}
func subscriptionsClient(baseURI string) subscriptions.Client {
c := subscriptions.NewClientWithBaseURI(baseURI, "") // used only for unauthenticated requests for generic subs IDs
c.Client.UserAgent += userAgent
return c
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common

View File

@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common
import (
"time"
)
type InterruptibleTaskResult struct {
Err error
IsCancelled bool
}
type InterruptibleTask struct {
IsCancelled func() bool
Task func(cancelCh <-chan struct{}) error
}
func NewInterruptibleTask(isCancelled func() bool, task func(cancelCh <-chan struct{}) error) *InterruptibleTask {
return &InterruptibleTask{
IsCancelled: isCancelled,
Task: task,
}
}
func StartInterruptibleTask(isCancelled func() bool, task func(cancelCh <-chan struct{}) error) InterruptibleTaskResult {
t := NewInterruptibleTask(isCancelled, task)
return t.Run()
}
func (s *InterruptibleTask) Run() InterruptibleTaskResult {
completeCh := make(chan error)
cancelCh := make(chan struct{})
defer close(cancelCh)
go func() {
err := s.Task(cancelCh)
completeCh <- err
// senders close, receivers check for close
close(completeCh)
}()
for {
if s.IsCancelled() {
return InterruptibleTaskResult{Err: nil, IsCancelled: true}
}
select {
case err := <-completeCh:
return InterruptibleTaskResult{Err: err, IsCancelled: false}
case <-time.After(100 * time.Millisecond):
}
}
}

View File

@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common
import (
"fmt"
"testing"
"time"
)
func TestInterruptibleTaskShouldImmediatelyEndOnCancel(t *testing.T) {
testSubject := NewInterruptibleTask(
func() bool { return true },
func(<-chan struct{}) error {
for {
time.Sleep(time.Second * 30)
}
})
result := testSubject.Run()
if result.IsCancelled != true {
t.Fatalf("Expected the task to be cancelled, but it was not.")
}
}
func TestInterruptibleTaskShouldRunTaskUntilCompletion(t *testing.T) {
var count int
testSubject := &InterruptibleTask{
IsCancelled: func() bool {
return false
},
Task: func(<-chan struct{}) error {
for i := 0; i < 10; i++ {
count += 1
}
return nil
},
}
result := testSubject.Run()
if result.IsCancelled != false {
t.Errorf("Expected the task to *not* be cancelled, but it was.")
}
if count != 10 {
t.Errorf("Expected the task to wait for completion, but it did not.")
}
if result.Err != nil {
t.Errorf("Expected the task to return a nil error, but got=%s", result.Err)
}
}
func TestInterruptibleTaskShouldImmediatelyStopOnTaskError(t *testing.T) {
testSubject := &InterruptibleTask{
IsCancelled: func() bool {
return false
},
Task: func(cancelCh <-chan struct{}) error {
return fmt.Errorf("boom")
},
}
result := testSubject.Run()
if result.IsCancelled != false {
t.Errorf("Expected the task to *not* be cancelled, but it was.")
}
if result.Err == nil {
t.Errorf("Expected the task to return an error, but it did not.")
}
}
func TestInterruptibleTaskShouldProvideLiveChannel(t *testing.T) {
testSubject := &InterruptibleTask{
IsCancelled: func() bool {
return false
},
Task: func(cancelCh <-chan struct{}) error {
isOpen := false
select {
case _, ok := <-cancelCh:
isOpen = !ok
if !isOpen {
t.Errorf("Expected the channel to open, but it was closed.")
}
default:
isOpen = true
break
}
if !isOpen {
t.Errorf("Check for openness failed.")
}
return nil
},
}
testSubject.Run()
}

View File

@ -1,13 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package lin
import (
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/builder/azure/common/constants"
"github.com/mitchellh/multistep"
"golang.org/x/crypto/ssh"
)

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package lin
@ -28,11 +28,11 @@ type StepCreateCert struct {
func (s *StepCreateCert) Run(state multistep.StateBag) multistep.StepAction {
ui := state.Get("ui").(packer.Ui)
ui.Say("Creating Temporary Certificate...")
ui.Say("Creating temporary certificate...")
err := s.createCert(state)
if err != nil {
err := fmt.Errorf("Error Creating Temporary Certificate: %s", err)
err := fmt.Errorf("Error creating temporary certificate: %s", err)
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt

View File

@ -1,15 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package lin
import (
"bytes"
"fmt"
"log"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
"log"
)
type StepGeneralizeOS struct {

View File

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package logutil
import "fmt"
type Fields map[string]interface{}
func (f Fields) String() string {
var s string
for k, v := range f {
if sv, ok := v.(string); ok {
v = fmt.Sprintf("%q", sv)
}
s += fmt.Sprintf(" %s=%v", k, v)
}
return s
}

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common

View File

@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in builder/azure for license information.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common

View File

@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
package common
import "github.com/mitchellh/multistep"
func IsStateCancelled(stateBag multistep.StateBag) bool {
_, ok := stateBag.GetOk(multistep.StateCancelled)
return ok
}

View File

@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See the LICENSE file in the project root for license information.
// NOTE: vault APIs do not yet exist in the SDK, but once they do this code
// should be removed.
package common
import (
"net/http"
"strings"
"github.com/Azure/go-autorest/autorest"
)
const (
AzureVaultApiVersion = "2015-06-01"
AzureVaultScope = "https://vault.azure.net"
AzureVaultSecretTemplate = "https://{vault-name}.vault.azure.net/secrets/{secret-name}"
)
type VaultClient struct {
autorest.Client
}
type Secret struct {
ID *string `json:"id,omitempty"`
Value string `json:"value"`
Attributes SecretAttributes `json:"attributes"`
}
type SecretAttributes struct {
Enabled bool `json:"enabled"`
Created *string `json:"created"`
Updated *string `json:"updated"`
}
func (client *VaultClient) GetSecret(vaultName, secretName string) (*Secret, error) {
p := map[string]interface{}{
"secret-name": secretName,
}
q := map[string]interface{}{
"api-version": AzureVaultApiVersion,
}
secretURL := strings.Replace(AzureVaultSecretTemplate, "{vault-name}", vaultName, -1)
req, err := autorest.Prepare(&http.Request{},
autorest.AsGet(),
autorest.WithBaseURL(secretURL),
autorest.WithPathParameters(p),
autorest.WithQueryParameters(q))
if err != nil {
return nil, err
}
//resp, err := v.Send(req, http.StatusOK)
resp, err := autorest.SendWithSender(client, req)
if err != nil {
return nil, err
}
var secret Secret
err = autorest.Respond(
resp,
autorest.ByUnmarshallingJSON(&secret))
if err != nil {
return nil, err
}
return &secret, nil
}

View File

@ -0,0 +1,9 @@
This is a fork of the from the original PKCS#12 parsing code
published in the Azure repository [go-pkcs12](https://github.com/Azure/go-pkcs12).
This fork adds serializing a x509 certificate and private key as PKCS#12 binary blob
(aka .PFX file). Due to the specific nature of this code it was not accepted for
inclusion in the official Go crypto repository.
The methods used for decoding PKCS#12 have been moved to the test files to further
discourage the use of this library for decoding. Please use the official
[pkcs12](https://godoc.org/golang.org/x/crypto/pkcs12) library for decode support.

View File

@ -0,0 +1,26 @@
package pkcs12
import (
"errors"
"unicode/utf16"
)
func bmpString(s string) ([]byte, error) {
// References:
// https://tools.ietf.org/html/rfc7292#appendix-B.1
// http://en.wikipedia.org/wiki/Plane_(Unicode)#Basic_Multilingual_Plane
// - non-BMP characters are encoded in UTF 16 by using a surrogate pair of 16-bit codes
// EncodeRune returns 0xfffd if the rune does not need special encoding
// - the above RFC provides the info that BMPStrings are NULL terminated.
rv := make([]byte, 0, 2*len(s)+2)
for _, r := range s {
if t, _ := utf16.EncodeRune(r); t != 0xfffd {
return nil, errors.New("string contains characters that cannot be encoded in UCS-2")
}
rv = append(rv, byte(r/256), byte(r%256))
}
rv = append(rv, 0, 0)
return rv, nil
}

View File

@ -0,0 +1,70 @@
package pkcs12
import (
"bytes"
"errors"
"testing"
"unicode/utf16"
)
func decodeBMPString(bmpString []byte) (string, error) {
if len(bmpString)%2 != 0 {
return "", errors.New("expected BMP byte string to be an even length")
}
// strip terminator if present
if terminator := bmpString[len(bmpString)-2:]; terminator[0] == terminator[1] && terminator[1] == 0 {
bmpString = bmpString[:len(bmpString)-2]
}
s := make([]uint16, 0, len(bmpString)/2)
for len(bmpString) > 0 {
s = append(s, uint16(bmpString[0])*265+uint16(bmpString[1]))
bmpString = bmpString[2:]
}
return string(utf16.Decode(s)), nil
}
func TestBMPStringDecode(t *testing.T) {
_, err := decodeBMPString([]byte("a"))
if err == nil {
t.Fatalf("expected decode to fail, but it succeeded")
}
}
func TestBMPString(t *testing.T) {
str, err := bmpString("")
if bytes.Compare(str, []byte{0, 0}) != 0 {
t.Errorf("expected empty string to return double 0, but found: % x", str)
}
if err != nil {
t.Errorf("err: %v", err)
}
// Example from https://tools.ietf.org/html/rfc7292#appendix-B
str, err = bmpString("Beavis")
if bytes.Compare(str, []byte{0x00, 0x42, 0x00, 0x65, 0x00, 0x61, 0x00, 0x0076, 0x00, 0x69, 0x00, 0x73, 0x00, 0x00}) != 0 {
t.Errorf("expected 'Beavis' to return 0x00 0x42 0x00 0x65 0x00 0x61 0x00 0x76 0x00 0x69 0x00 0x73 0x00 0x00, but found: % x", str)
}
if err != nil {
t.Errorf("err: %v", err)
}
// some characters from the "Letterlike Symbols Unicode block"
tst := "\u2115 - Double-struck N"
str, err = bmpString(tst)
if bytes.Compare(str, []byte{0x21, 0x15, 0x00, 0x20, 0x00, 0x2d, 0x00, 0x20, 0x00, 0x44, 0x00, 0x6f, 0x00, 0x75, 0x00, 0x62, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x2d, 0x00, 0x73, 0x00, 0x74, 0x00, 0x72, 0x00, 0x75, 0x00, 0x63, 0x00, 0x6b, 0x00, 0x20, 0x00, 0x4e, 0x00, 0x00}) != 0 {
t.Errorf("expected '%s' to return 0x21 0x15 0x00 0x20 0x00 0x2d 0x00 0x20 0x00 0x44 0x00 0x6f 0x00 0x75 0x00 0x62 0x00 0x6c 0x00 0x65 0x00 0x2d 0x00 0x73 0x00 0x74 0x00 0x72 0x00 0x75 0x00 0x63 0x00 0x6b 0x00 0x20 0x00 0x4e 0x00 0x00, but found: % x", tst, str)
}
if err != nil {
t.Errorf("err: %v", err)
}
// some character outside the BMP should error
tst = "\U0001f000 East wind (Mahjong)"
str, err = bmpString(tst)
if err == nil {
t.Errorf("expected '%s' to throw error because the first character is not in the BMP", tst)
}
}

View File

@ -0,0 +1,82 @@
// implementation of https://tools.ietf.org/html/rfc2898#section-6.1.2
package pkcs12
import (
"bytes"
"crypto/cipher"
"crypto/des"
"crypto/rand"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"io"
"github.com/mitchellh/packer/builder/azure/pkcs12/rc2"
)
const (
pbeWithSHAAnd3KeyTripleDESCBC = "pbeWithSHAAnd3-KeyTripleDES-CBC"
pbewithSHAAnd40BitRC2CBC = "pbewithSHAAnd40BitRC2-CBC"
)
const (
pbeIterationCount = 2048
pbeSaltSizeBytes = 8
)
var (
oidPbeWithSHAAnd3KeyTripleDESCBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 1, 3}
oidPbewithSHAAnd40BitRC2CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 1, 6}
)
var algByOID = map[string]string{
oidPbeWithSHAAnd3KeyTripleDESCBC.String(): pbeWithSHAAnd3KeyTripleDESCBC,
oidPbewithSHAAnd40BitRC2CBC.String(): pbewithSHAAnd40BitRC2CBC,
}
var blockcodeByAlg = map[string]func(key []byte) (cipher.Block, error){
pbeWithSHAAnd3KeyTripleDESCBC: des.NewTripleDESCipher,
pbewithSHAAnd40BitRC2CBC: func(key []byte) (cipher.Block, error) {
return rc2.New(key, len(key)*8)
},
}
type pbeParams struct {
Salt []byte
Iterations int
}
func pad(src []byte, blockSize int) []byte {
paddingLength := blockSize - len(src)%blockSize
paddingText := bytes.Repeat([]byte{byte(paddingLength)}, paddingLength)
return append(src, paddingText...)
}
func pbEncrypt(plainText, salt, password []byte, iterations int) (cipherText []byte, err error) {
_, err = io.ReadFull(rand.Reader, salt)
if err != nil {
return nil, errors.New("pkcs12: failed to create a random salt value: " + err.Error())
}
key := deriveKeyByAlg[pbeWithSHAAnd3KeyTripleDESCBC](salt, password, iterations)
iv := deriveIVByAlg[pbeWithSHAAnd3KeyTripleDESCBC](salt, password, iterations)
block, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, errors.New("pkcs12: failed to create a block cipher: " + err.Error())
}
paddedPlainText := pad(plainText, block.BlockSize())
encrypter := cipher.NewCBCEncrypter(block, iv)
cipherText = make([]byte, len(paddedPlainText))
encrypter.CryptBlocks(cipherText, paddedPlainText)
return cipherText, nil
}
type decryptable interface {
GetAlgorithm() pkix.AlgorithmIdentifier
GetData() []byte
}

View File

@ -0,0 +1,159 @@
package pkcs12
import (
"bytes"
"crypto/cipher"
"crypto/x509/pkix"
"encoding/asn1"
"testing"
)
func pbDecrypterFor(algorithm pkix.AlgorithmIdentifier, password []byte) (cipher.BlockMode, error) {
algorithmName, supported := algByOID[algorithm.Algorithm.String()]
if !supported {
return nil, NotImplementedError("algorithm " + algorithm.Algorithm.String() + " is not supported")
}
var params pbeParams
if _, err := asn1.Unmarshal(algorithm.Parameters.FullBytes, &params); err != nil {
return nil, err
}
k := deriveKeyByAlg[algorithmName](params.Salt, password, params.Iterations)
iv := deriveIVByAlg[algorithmName](params.Salt, password, params.Iterations)
password = nil
code, err := blockcodeByAlg[algorithmName](k)
if err != nil {
return nil, err
}
cbc := cipher.NewCBCDecrypter(code, iv)
return cbc, nil
}
func pbDecrypt(info decryptable, password []byte) (decrypted []byte, err error) {
cbc, err := pbDecrypterFor(info.GetAlgorithm(), password)
password = nil
if err != nil {
return nil, err
}
encrypted := info.GetData()
decrypted = make([]byte, len(encrypted))
cbc.CryptBlocks(decrypted, encrypted)
if psLen := int(decrypted[len(decrypted)-1]); psLen > 0 && psLen <= cbc.BlockSize() {
m := decrypted[:len(decrypted)-psLen]
ps := decrypted[len(decrypted)-psLen:]
if bytes.Compare(ps, bytes.Repeat([]byte{byte(psLen)}, psLen)) != 0 {
return nil, ErrDecryption
}
decrypted = m
} else {
return nil, ErrDecryption
}
return
}
func TestPbDecrypterFor(t *testing.T) {
params, _ := asn1.Marshal(pbeParams{
Salt: []byte{1, 2, 3, 4, 5, 6, 7, 8},
Iterations: 2048,
})
alg := pkix.AlgorithmIdentifier{
Algorithm: asn1.ObjectIdentifier([]int{1, 2, 3}),
Parameters: asn1.RawValue{
FullBytes: params,
},
}
pass, _ := bmpString("Sesame open")
_, err := pbDecrypterFor(alg, pass)
if _, ok := err.(NotImplementedError); !ok {
t.Errorf("expected not implemented error, got: %T %s", err, err)
}
alg.Algorithm = asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 1, 3})
cbc, err := pbDecrypterFor(alg, pass)
if err != nil {
t.Errorf("err: %v", err)
}
M := []byte{1, 2, 3, 4, 5, 6, 7, 8}
expectedM := []byte{185, 73, 135, 249, 137, 1, 122, 247}
cbc.CryptBlocks(M, M)
if bytes.Compare(M, expectedM) != 0 {
t.Errorf("expected M to be '%d', but found '%d", expectedM, M)
}
}
func TestPbDecrypt(t *testing.T) {
tests := [][]byte{
[]byte("\x33\x73\xf3\x9f\xda\x49\xae\xfc\xa0\x9a\xdf\x5a\x58\xa0\xea\x46"), // 7 padding bytes
[]byte("\x33\x73\xf3\x9f\xda\x49\xae\xfc\x96\x24\x2f\x71\x7e\x32\x3f\xe7"), // 8 padding bytes
[]byte("\x35\x0c\xc0\x8d\xab\xa9\x5d\x30\x7f\x9a\xec\x6a\xd8\x9b\x9c\xd9"), // 9 padding bytes, incorrect
[]byte("\xb2\xf9\x6e\x06\x60\xae\x20\xcf\x08\xa0\x7b\xd9\x6b\x20\xef\x41"), // incorrect padding bytes: [ ... 0x04 0x02 ]
}
expected := []interface{}{
[]byte("A secret!"),
[]byte("A secret"),
ErrDecryption,
ErrDecryption,
}
for i, c := range tests {
td := testDecryptable{
data: c,
algorithm: pkix.AlgorithmIdentifier{
Algorithm: asn1.ObjectIdentifier([]int{1, 2, 840, 113549, 1, 12, 1, 3}), // SHA1/3TDES
Parameters: pbeParams{
Salt: []byte("\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8"),
Iterations: 4096,
}.RawASN1(),
},
}
p, _ := bmpString("sesame")
m, err := pbDecrypt(td, p)
switch e := expected[i].(type) {
case []byte:
if err != nil {
t.Errorf("error decrypting C=%x: %v", c, err)
}
if bytes.Compare(m, e) != 0 {
t.Errorf("expected C=%x to be decoded to M=%x, but found %x", c, e, m)
}
case error:
if err == nil || err.Error() != e.Error() {
t.Errorf("expecting error '%v' during decryption of c=%x, but found err='%v'", e, c, err)
}
}
}
}
type testDecryptable struct {
data []byte
algorithm pkix.AlgorithmIdentifier
}
func (d testDecryptable) GetAlgorithm() pkix.AlgorithmIdentifier { return d.algorithm }
func (d testDecryptable) GetData() []byte { return d.data }
func (params pbeParams) RawASN1() (raw asn1.RawValue) {
asn1Bytes, err := asn1.Marshal(params)
if err != nil {
panic(err)
}
_, err = asn1.Unmarshal(asn1Bytes, &raw)
if err != nil {
panic(err)
}
return
}

View File

@ -0,0 +1,24 @@
package pkcs12
import "errors"
var (
// ErrDecryption represents a failure to decrypt the input.
ErrDecryption = errors.New("pkcs12: decryption error, incorrect padding")
// ErrIncorrectPassword is returned when an incorrect password is detected.
// Usually, P12/PFX data is signed to be able to verify the password.
ErrIncorrectPassword = errors.New("pkcs12: decryption password incorrect")
)
// NotImplementedError indicates that the input is not currently supported.
type NotImplementedError string
type EncodeError string
func (e NotImplementedError) Error() string {
return string(e)
}
func (e EncodeError) Error() string {
return "pkcs12: encode error: " + string(e)
}

View File

@ -0,0 +1,33 @@
package pkcs12
import (
"crypto/hmac"
"crypto/sha1"
"crypto/x509/pkix"
"encoding/asn1"
)
var (
oidSha1Algorithm = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 26}
)
type macData struct {
Mac digestInfo
MacSalt []byte
Iterations int `asn1:"optional,default:1"`
}
// from PKCS#7:
type digestInfo struct {
Algorithm pkix.AlgorithmIdentifier
Digest []byte
}
func computeMac(message []byte, iterations int, salt, password []byte) []byte {
key := pbkdf(sha1Sum, 20, 64, salt, password, iterations, 3, 20)
mac := hmac.New(sha1.New, key)
mac.Write(message)
return mac.Sum(nil)
}

View File

@ -0,0 +1,52 @@
package pkcs12
import (
"crypto/hmac"
"encoding/asn1"
"testing"
)
func verifyMac(macData *macData, message, password []byte) error {
if !macData.Mac.Algorithm.Algorithm.Equal(oidSha1Algorithm) {
return NotImplementedError("unknown digest algorithm: " + macData.Mac.Algorithm.Algorithm.String())
}
expectedMAC := computeMac(message, macData.Iterations, macData.MacSalt, password)
if !hmac.Equal(macData.Mac.Digest, expectedMAC) {
return ErrIncorrectPassword
}
return nil
}
func TestVerifyMac(t *testing.T) {
td := macData{
Mac: digestInfo{
Digest: []byte{0x18, 0x20, 0x3d, 0xff, 0x1e, 0x16, 0xf4, 0x92, 0xf2, 0xaf, 0xc8, 0x91, 0xa9, 0xba, 0xd6, 0xca, 0x9d, 0xee, 0x51, 0x93},
},
MacSalt: []byte{1, 2, 3, 4, 5, 6, 7, 8},
Iterations: 2048,
}
message := []byte{11, 12, 13, 14, 15}
password, _ := bmpString("")
td.Mac.Algorithm.Algorithm = asn1.ObjectIdentifier([]int{1, 2, 3})
err := verifyMac(&td, message, password)
if _, ok := err.(NotImplementedError); !ok {
t.Errorf("err: %v", err)
}
td.Mac.Algorithm.Algorithm = asn1.ObjectIdentifier([]int{1, 3, 14, 3, 2, 26})
err = verifyMac(&td, message, password)
if err != ErrIncorrectPassword {
t.Errorf("Expected incorrect password, got err: %v", err)
}
password, _ = bmpString("Sesame open")
err = verifyMac(&td, message, password)
if err != nil {
t.Errorf("err: %v", err)
}
}

View File

@ -0,0 +1,187 @@
package pkcs12
import (
"crypto/sha1"
"math/big"
)
var (
deriveKeyByAlg = map[string]func(salt, password []byte, iterations int) []byte{
pbeWithSHAAnd3KeyTripleDESCBC: func(salt, password []byte, iterations int) []byte {
return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 1, 24)
},
pbewithSHAAnd40BitRC2CBC: func(salt, password []byte, iterations int) []byte {
return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 1, 5)
},
}
deriveIVByAlg = map[string]func(salt, password []byte, iterations int) []byte{
pbeWithSHAAnd3KeyTripleDESCBC: func(salt, password []byte, iterations int) []byte {
return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 2, 8)
},
pbewithSHAAnd40BitRC2CBC: func(salt, password []byte, iterations int) []byte {
return pbkdf(sha1Sum, 20, 64, salt, password, iterations, 2, 8)
},
}
)
func sha1Sum(in []byte) []byte {
sum := sha1.Sum(in)
return sum[:]
}
func pbkdf(hash func([]byte) []byte, u, v int, salt, password []byte, r int, ID byte, size int) (key []byte) {
// implementation of https://tools.ietf.org/html/rfc7292#appendix-B.2 , RFC text verbatim in comments
// Let H be a hash function built around a compression function f:
// Z_2^u x Z_2^v -> Z_2^u
// (that is, H has a chaining variable and output of length u bits, and
// the message input to the compression function of H is v bits). The
// values for u and v are as follows:
// HASH FUNCTION VALUE u VALUE v
// MD2, MD5 128 512
// SHA-1 160 512
// SHA-224 224 512
// SHA-256 256 512
// SHA-384 384 1024
// SHA-512 512 1024
// SHA-512/224 224 1024
// SHA-512/256 256 1024
// Furthermore, let r be the iteration count.
// We assume here that u and v are both multiples of 8, as are the
// lengths of the password and salt strings (which we denote by p and s,
// respectively) and the number n of pseudorandom bits required. In
// addition, u and v are of course non-zero.
// For information on security considerations for MD5 [19], see [25] and
// [1], and on those for MD2, see [18].
// The following procedure can be used to produce pseudorandom bits for
// a particular "purpose" that is identified by a byte called "ID".
// This standard specifies 3 different values for the ID byte:
// 1. If ID=1, then the pseudorandom bits being produced are to be used
// as key material for performing encryption or decryption.
// 2. If ID=2, then the pseudorandom bits being produced are to be used
// as an IV (Initial Value) for encryption or decryption.
// 3. If ID=3, then the pseudorandom bits being produced are to be used
// as an integrity key for MACing.
// 1. Construct a string, D (the "diversifier"), by concatenating v/8
// copies of ID.
D := []byte{}
for i := 0; i < v; i++ {
D = append(D, ID)
}
// 2. Concatenate copies of the salt together to create a string S of
// length v(ceiling(s/v)) bits (the final copy of the salt may be
// truncated to create S). Note that if the salt is the empty
// string, then so is S.
S := []byte{}
{
s := len(salt)
times := s / v
if s%v > 0 {
times++
}
for len(S) < times*v {
S = append(S, salt...)
}
S = S[:times*v]
}
// 3. Concatenate copies of the password together to create a string P
// of length v(ceiling(p/v)) bits (the final copy of the password
// may be truncated to create P). Note that if the password is the
// empty string, then so is P.
P := []byte{}
{
s := len(password)
times := s / v
if s%v > 0 {
times++
}
for len(P) < times*v {
P = append(P, password...)
}
password = nil
P = P[:times*v]
}
// 4. Set I=S||P to be the concatenation of S and P.
I := append(S, P...)
// 5. Set c=ceiling(n/u).
c := size / u
if size%u > 0 {
c++
}
// 6. For i=1, 2, ..., c, do the following:
A := make([]byte, c*20)
for i := 0; i < c; i++ {
// A. Set A2=H^r(D||I). (i.e., the r-th hash of D||1,
// H(H(H(... H(D||I))))
Ai := hash(append(D, I...))
for j := 1; j < r; j++ {
Ai = hash(Ai[:])
}
copy(A[i*20:], Ai[:])
if i < c-1 { // skip on last iteration
// B. Concatenate copies of Ai to create a string B of length v
// bits (the final copy of Ai may be truncated to create B).
B := []byte{}
for len(B) < v {
B = append(B, Ai[:]...)
}
B = B[:v]
// C. Treating I as a concatenation I_0, I_1, ..., I_(k-1) of v-bit
// blocks, where k=ceiling(s/v)+ceiling(p/v), modify I by
// setting I_j=(I_j+B+1) mod 2^v for each j.
{
Bbi := new(big.Int)
Bbi.SetBytes(B)
one := big.NewInt(1)
for j := 0; j < len(I)/v; j++ {
Ij := new(big.Int)
Ij.SetBytes(I[j*v : (j+1)*v])
Ij.Add(Ij, Bbi)
Ij.Add(Ij, one)
Ijb := Ij.Bytes()
if len(Ijb) > v {
Ijb = Ijb[len(Ijb)-v:]
}
copy(I[j*v:(j+1)*v], Ijb)
}
}
}
}
// 7. Concatenate A_1, A_2, ..., A_c together to form a pseudorandom
// bit string, A.
// 8. Use the first n bits of A as the output of this entire process.
A = A[:size]
return A
// If the above process is being used to generate a DES key, the process
// should be used to create 64 random bits, and the key's parity bits
// should be set after the 64 bits have been produced. Similar concerns
// hold for 2-key and 3-key triple-DES keys, for CDMF keys, and for any
// similar keys with parity bits "built into them".
}

View File

@ -0,0 +1,18 @@
package pkcs12
import (
"bytes"
"testing"
)
func TestThatPBKDFWorksCorrectlyForLongKeys(t *testing.T) {
pbkdf := deriveKeyByAlg[pbeWithSHAAnd3KeyTripleDESCBC]
salt := []byte("\xff\xff\xff\xff\xff\xff\xff\xff")
password, _ := bmpString("sesame")
key := pbkdf(salt, password, 2048)
if expected := []byte("\x7c\xd9\xfd\x3e\x2b\x3b\xe7\x69\x1a\x44\xe3\xbe\xf0\xf9\xea\x0f\xb9\xb8\x97\xd4\xe3\x25\xd9\xd1"); bytes.Compare(key, expected) != 0 {
t.Fatalf("expected key '% x', but found '% x'", key, expected)
}
}

View File

@ -0,0 +1,263 @@
// Package pkcs12 provides some implementations of PKCS#12.
//
// This implementation is distilled from https://tools.ietf.org/html/rfc7292 and referenced documents.
// It is intended for decoding P12/PFX-stored certificate+key for use with the crypto/tls package.
package pkcs12
import (
"crypto/rand"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"io"
)
var (
oidLocalKeyID = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 21}
oidDataContentType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 1}
localKeyId = []byte{0x01, 0x00, 0x00, 0x00}
)
type pfxPdu struct {
Version int
AuthSafe contentInfo
MacData macData `asn1:"optional"`
}
type contentInfo struct {
ContentType asn1.ObjectIdentifier
Content asn1.RawValue `asn1:"tag:0,explicit,optional"`
}
type encryptedData struct {
Version int
EncryptedContentInfo encryptedContentInfo
}
type encryptedContentInfo struct {
ContentType asn1.ObjectIdentifier
ContentEncryptionAlgorithm pkix.AlgorithmIdentifier
EncryptedContent []byte `asn1:"tag:0,optional"`
}
func (i encryptedContentInfo) GetAlgorithm() pkix.AlgorithmIdentifier {
return i.ContentEncryptionAlgorithm
}
func (i encryptedContentInfo) GetData() []byte { return i.EncryptedContent }
type safeBag struct {
Id asn1.ObjectIdentifier
Value asn1.RawValue `asn1:"tag:0,explicit"`
Attributes []pkcs12Attribute `asn1:"set,optional"`
}
type pkcs12Attribute struct {
Id asn1.ObjectIdentifier
Value asn1.RawValue `ans1:"set"`
}
type encryptedPrivateKeyInfo struct {
AlgorithmIdentifier pkix.AlgorithmIdentifier
EncryptedData []byte
}
func (i encryptedPrivateKeyInfo) GetAlgorithm() pkix.AlgorithmIdentifier { return i.AlgorithmIdentifier }
func (i encryptedPrivateKeyInfo) GetData() []byte { return i.EncryptedData }
// unmarshal calls asn1.Unmarshal, but also returns an error if there is any
// trailing data after unmarshaling.
func unmarshal(in []byte, out interface{}) error {
trailing, err := asn1.Unmarshal(in, out)
if err != nil {
return err
}
if len(trailing) != 0 {
return errors.New("pkcs12: trailing data found")
}
return nil
}
func getLocalKeyId(id []byte) (attribute pkcs12Attribute, err error) {
octetString := asn1.RawValue{Tag: 4, Class: 0, IsCompound: false, Bytes: id}
bytes, err := asn1.Marshal(octetString)
if err != nil {
return
}
attribute = pkcs12Attribute{
Id: oidLocalKeyID,
Value: asn1.RawValue{Tag: 17, Class: 0, IsCompound: true, Bytes: bytes},
}
return attribute, nil
}
func convertToRawVal(val interface{}) (raw asn1.RawValue, err error) {
bytes, err := asn1.Marshal(val)
if err != nil {
return
}
_, err = asn1.Unmarshal(bytes, &raw)
return raw, nil
}
func makeSafeBags(oid asn1.ObjectIdentifier, value []byte) ([]safeBag, error) {
attribute, err := getLocalKeyId(localKeyId)
if err != nil {
return nil, EncodeError("local key id: " + err.Error())
}
bag := make([]safeBag, 1)
bag[0] = safeBag{
Id: oid,
Value: asn1.RawValue{Tag: 0, Class: 2, IsCompound: true, Bytes: value},
Attributes: []pkcs12Attribute{attribute},
}
return bag, nil
}
func makeCertBagContentInfo(derBytes []byte) (*contentInfo, error) {
certBag1 := certBag{
Id: oidCertTypeX509Certificate,
Data: derBytes,
}
bytes, err := asn1.Marshal(certBag1)
if err != nil {
return nil, EncodeError("encoding cert bag: " + err.Error())
}
certSafeBags, err := makeSafeBags(oidCertBagType, bytes)
if err != nil {
return nil, EncodeError("safe bags: " + err.Error())
}
return makeContentInfo(certSafeBags)
}
func makeShroudedKeyBagContentInfo(privateKey interface{}, password []byte) (*contentInfo, error) {
shroudedKeyBagBytes, err := encodePkcs8ShroudedKeyBag(privateKey, password)
if err != nil {
return nil, EncodeError("encode PKCS#8 shrouded key bag: " + err.Error())
}
safeBags, err := makeSafeBags(oidPkcs8ShroudedKeyBagType, shroudedKeyBagBytes)
if err != nil {
return nil, EncodeError("safe bags: " + err.Error())
}
return makeContentInfo(safeBags)
}
func makeContentInfo(val interface{}) (*contentInfo, error) {
fullBytes, err := asn1.Marshal(val)
if err != nil {
return nil, EncodeError("contentInfo raw value marshal: " + err.Error())
}
octetStringVal := asn1.RawValue{Tag: 4, Class: 0, IsCompound: false, Bytes: fullBytes}
octetStringFullBytes, err := asn1.Marshal(octetStringVal)
if err != nil {
return nil, EncodeError("raw contentInfo to octet string: " + err.Error())
}
contentInfo := contentInfo{ContentType: oidDataContentType}
contentInfo.Content = asn1.RawValue{Tag: 0, Class: 2, IsCompound: true, Bytes: octetStringFullBytes}
return &contentInfo, nil
}
func makeContentInfos(derBytes []byte, privateKey interface{}, password []byte) ([]contentInfo, error) {
shroudedKeyContentInfo, err := makeShroudedKeyBagContentInfo(privateKey, password)
if err != nil {
return nil, EncodeError("shrouded key content info: " + err.Error())
}
certBagContentInfo, err := makeCertBagContentInfo(derBytes)
if err != nil {
return nil, EncodeError("cert bag content info: " + err.Error())
}
contentInfos := make([]contentInfo, 2)
contentInfos[0] = *shroudedKeyContentInfo
contentInfos[1] = *certBagContentInfo
return contentInfos, nil
}
func makeSalt(saltByteCount int) ([]byte, error) {
salt := make([]byte, saltByteCount)
_, err := io.ReadFull(rand.Reader, salt)
return salt, err
}
// Encode converts a certificate and a private key to the PKCS#12 byte stream format.
//
// derBytes is a DER encoded certificate.
// privateKey is an RSA
func Encode(derBytes []byte, privateKey interface{}, password string) (pfxBytes []byte, err error) {
secret, err := bmpString(password)
if err != nil {
return nil, ErrIncorrectPassword
}
contentInfos, err := makeContentInfos(derBytes, privateKey, secret)
if err != nil {
return nil, err
}
// Marhsal []contentInfo so we can re-constitute the byte stream that will
// be suitable for computing the MAC
bytes, err := asn1.Marshal(contentInfos)
if err != nil {
return nil, err
}
// Unmarshal as an asn1.RawValue so, we can compute the MAC against the .Bytes
var contentInfosRaw asn1.RawValue
err = unmarshal(bytes, &contentInfosRaw)
if err != nil {
return nil, err
}
authSafeContentInfo, err := makeContentInfo(contentInfosRaw)
if err != nil {
return nil, EncodeError("authSafe content info: " + err.Error())
}
salt, err := makeSalt(pbeSaltSizeBytes)
if err != nil {
return nil, EncodeError("salt value: " + err.Error())
}
// Compute the MAC for marshaled bytes of contentInfos, which includes the
// cert bag, and the shrouded key bag.
digest := computeMac(contentInfosRaw.FullBytes, pbeIterationCount, salt, secret)
pfx := pfxPdu{
Version: 3,
AuthSafe: *authSafeContentInfo,
MacData: macData{
Iterations: pbeIterationCount,
MacSalt: salt,
Mac: digestInfo{
Algorithm: pkix.AlgorithmIdentifier{
Algorithm: oidSha1Algorithm,
},
Digest: digest,
},
},
}
bytes, err = asn1.Marshal(pfx)
if err != nil {
return nil, EncodeError("marshal PFX PDU: " + err.Error())
}
return bytes, err
}

View File

@ -0,0 +1,117 @@
package pkcs12
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"testing"
"time"
gopkcs12 "golang.org/x/crypto/pkcs12"
)
func TestPfxRoundTriRsa(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatal(err.Error())
}
key := testPfxRoundTrip(t, privateKey)
actualPrivateKey, ok := key.(*rsa.PrivateKey)
if !ok {
t.Fatalf("failed to decode private key")
}
if privateKey.D.Cmp(actualPrivateKey.D) != 0 {
t.Errorf("priv.D")
}
}
func TestPfxRoundTriEcdsa(t *testing.T) {
privateKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatal(err.Error())
}
key := testPfxRoundTrip(t, privateKey)
actualPrivateKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
t.Fatalf("failed to decode private key")
}
if privateKey.D.Cmp(actualPrivateKey.D) != 0 {
t.Errorf("priv.D")
}
}
func testPfxRoundTrip(t *testing.T, privateKey interface{}) interface{} {
certificateBytes, err := newCertificate("hostname", privateKey)
if err != nil {
t.Fatal(err.Error())
}
bytes, err := Encode(certificateBytes, privateKey, "sesame")
if err != nil {
t.Fatal(err.Error())
}
key, _, err := gopkcs12.Decode(bytes, "sesame")
if err != nil {
t.Fatalf(err.Error())
}
return key
}
func newCertificate(hostname string, privateKey interface{}) ([]byte, error) {
t, _ := time.Parse("2006-01-02", "2016-01-01")
notBefore := t
notAfter := notBefore.Add(365 * 24 * time.Hour)
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
err := fmt.Errorf("Failed to Generate Serial Number: %v", err)
return nil, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Issuer: pkix.Name{
CommonName: hostname,
},
Subject: pkix.Name{
CommonName: hostname,
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
var publicKey interface{}
switch key := privateKey.(type) {
case *rsa.PrivateKey:
publicKey = key.Public()
case *ecdsa.PrivateKey:
publicKey = key.Public()
default:
panic(fmt.Sprintf("unsupported private key type: %T", privateKey))
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey, privateKey)
if err != nil {
return nil, fmt.Errorf("Failed to Generate derBytes: " + err.Error())
}
return derBytes, nil
}

View File

@ -0,0 +1,64 @@
package pkcs12
import (
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
)
// pkcs8 reflects an ASN.1, PKCS#8 PrivateKey. See
// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn
// and RFC5208.
type pkcs8 struct {
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
// optional attributes omitted.
}
var (
oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1}
oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}
nullAsn = asn1.RawValue{Tag: 5}
)
// marshalPKCS8PrivateKey converts a private key to PKCS#8 encoded form.
// See http://www.rsa.com/rsalabs/node.asp?id=2130 and RFC5208.
func marshalPKCS8PrivateKey(key interface{}) (der []byte, err error) {
pkcs := pkcs8{
Version: 0,
}
switch key := key.(type) {
case *rsa.PrivateKey:
pkcs.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyRSA,
Parameters: nullAsn,
}
pkcs.PrivateKey = x509.MarshalPKCS1PrivateKey(key)
case *ecdsa.PrivateKey:
bytes, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, errors.New("x509: failed to marshal to PKCS#8: " + err.Error())
}
pkcs.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyECDSA,
Parameters: nullAsn,
}
pkcs.PrivateKey = bytes
default:
return nil, errors.New("x509: PKCS#8 only RSA and ECDSA private keys supported")
}
bytes, err := asn1.Marshal(pkcs)
if err != nil {
return nil, errors.New("x509: failed to marshal to PKCS#8: " + err.Error())
}
return bytes, nil
}

View File

@ -0,0 +1,141 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pkcs12
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/asn1"
"testing"
)
func TestRoundTripPkcs8Rsa(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatalf("failed to generate a private key: %s", err)
}
bytes, err := marshalPKCS8PrivateKey(privateKey)
if err != nil {
t.Fatalf("failed to marshal private key: %s", err)
}
key, err := x509.ParsePKCS8PrivateKey(bytes)
if err != nil {
t.Fatalf("failed to parse private key: %s", err)
}
actualPrivateKey, ok := key.(*rsa.PrivateKey)
if !ok {
t.Fatalf("expected key to be of type *rsa.PrivateKey, but actual was %T", key)
}
if actualPrivateKey.Validate() != nil {
t.Fatalf("private key did not validate")
}
if actualPrivateKey.N.Cmp(privateKey.N) != 0 {
t.Errorf("private key's N did not round trip")
}
if actualPrivateKey.D.Cmp(privateKey.D) != 0 {
t.Errorf("private key's D did not round trip")
}
if actualPrivateKey.E != privateKey.E {
t.Errorf("private key's E did not round trip")
}
if actualPrivateKey.Primes[0].Cmp(privateKey.Primes[0]) != 0 {
t.Errorf("private key's P did not round trip")
}
if actualPrivateKey.Primes[1].Cmp(privateKey.Primes[1]) != 0 {
t.Errorf("private key's Q did not round trip")
}
}
func TestRoundTripPkcs8Ecdsa(t *testing.T) {
privateKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate a private key: %s", err)
}
bytes, err := marshalPKCS8PrivateKey(privateKey)
if err != nil {
t.Fatalf("failed to marshal private key: %s", err)
}
key, err := x509.ParsePKCS8PrivateKey(bytes)
if err != nil {
t.Fatalf("failed to parse private key: %s", err)
}
actualPrivateKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
t.Fatalf("expected key to be of type *ecdsa.PrivateKey, but actual was %T", key)
}
// sanity check, not exhaustive
if actualPrivateKey.D.Cmp(privateKey.D) != 0 {
t.Errorf("private key's D did not round trip")
}
if actualPrivateKey.X.Cmp(privateKey.X) != 0 {
t.Errorf("private key's X did not round trip")
}
if actualPrivateKey.Y.Cmp(privateKey.Y) != 0 {
t.Errorf("private key's Y did not round trip")
}
if actualPrivateKey.Curve.Params().B.Cmp(privateKey.Curve.Params().B) != 0 {
t.Errorf("private key's Curve.B did not round trip")
}
}
func TestNullParametersPkcs8Rsa(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatalf("failed to generate a private key: %s", err)
}
checkNullParameter(t, privateKey)
}
func TestNullParametersPkcs8Ecdsa(t *testing.T) {
privateKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate a private key: %s", err)
}
checkNullParameter(t, privateKey)
}
func checkNullParameter(t *testing.T, privateKey interface{}) {
bytes, err := marshalPKCS8PrivateKey(privateKey)
if err != nil {
t.Fatalf("failed to marshal private key: %s", err)
}
var pkcs pkcs8
rest, err := asn1.Unmarshal(bytes, &pkcs)
if err != nil {
t.Fatalf("failed to unmarshal PKCS#8: %s", err)
}
if len(rest) != 0 {
t.Fatalf("unexpected trailing bytes of len=%d, bytes=%x", len(rest), rest)
}
// Only version == 0 is known and valid
if pkcs.Version != 0 {
t.Errorf("expected version=0, but actual=%d", pkcs.Version)
}
// ensure a NULL parameter is inserted
if pkcs.Algo.Parameters.Tag != 5 {
t.Errorf("expected parameters to be NULL, but actual tag=%d, class=%d, isCompound=%t, bytes=%x",
pkcs.Algo.Parameters.Tag,
pkcs.Algo.Parameters.Class,
pkcs.Algo.Parameters.IsCompound,
pkcs.Algo.Parameters.Bytes)
}
}

View File

@ -0,0 +1,284 @@
// Package rc2 implements the RC2 cipher
/*
https://www.ietf.org/rfc/rfc2268.txt
http://people.csail.mit.edu/rivest/pubs/KRRR98.pdf
This code is licensed under the MIT license.
*/
package rc2
import (
"crypto/cipher"
"encoding/binary"
"strconv"
)
// The rc2 block size in bytes
const BlockSize = 8
type rc2Cipher struct {
k [64]uint16
}
// KeySizeError indicates the supplied key was invalid
type KeySizeError int
func (k KeySizeError) Error() string { return "rc2: invalid key size " + strconv.Itoa(int(k)) }
// EffectiveKeySizeError indicates the supplied effective key length was invalid
type EffectiveKeySizeError int
func (k EffectiveKeySizeError) Error() string {
return "rc2: invalid effective key size " + strconv.Itoa(int(k))
}
// New returns a new rc2 cipher with the given key and effective key length t1
func New(key []byte, t1 int) (cipher.Block, error) {
if l := len(key); l == 0 || l > 128 {
return nil, KeySizeError(l)
}
if t1 < 8 || t1 > 1024 {
return nil, EffectiveKeySizeError(t1)
}
return &rc2Cipher{
k: expandKey(key, t1),
}, nil
}
func (c *rc2Cipher) BlockSize() int { return BlockSize }
var piTable = [256]byte{
0xd9, 0x78, 0xf9, 0xc4, 0x19, 0xdd, 0xb5, 0xed, 0x28, 0xe9, 0xfd, 0x79, 0x4a, 0xa0, 0xd8, 0x9d,
0xc6, 0x7e, 0x37, 0x83, 0x2b, 0x76, 0x53, 0x8e, 0x62, 0x4c, 0x64, 0x88, 0x44, 0x8b, 0xfb, 0xa2,
0x17, 0x9a, 0x59, 0xf5, 0x87, 0xb3, 0x4f, 0x13, 0x61, 0x45, 0x6d, 0x8d, 0x09, 0x81, 0x7d, 0x32,
0xbd, 0x8f, 0x40, 0xeb, 0x86, 0xb7, 0x7b, 0x0b, 0xf0, 0x95, 0x21, 0x22, 0x5c, 0x6b, 0x4e, 0x82,
0x54, 0xd6, 0x65, 0x93, 0xce, 0x60, 0xb2, 0x1c, 0x73, 0x56, 0xc0, 0x14, 0xa7, 0x8c, 0xf1, 0xdc,
0x12, 0x75, 0xca, 0x1f, 0x3b, 0xbe, 0xe4, 0xd1, 0x42, 0x3d, 0xd4, 0x30, 0xa3, 0x3c, 0xb6, 0x26,
0x6f, 0xbf, 0x0e, 0xda, 0x46, 0x69, 0x07, 0x57, 0x27, 0xf2, 0x1d, 0x9b, 0xbc, 0x94, 0x43, 0x03,
0xf8, 0x11, 0xc7, 0xf6, 0x90, 0xef, 0x3e, 0xe7, 0x06, 0xc3, 0xd5, 0x2f, 0xc8, 0x66, 0x1e, 0xd7,
0x08, 0xe8, 0xea, 0xde, 0x80, 0x52, 0xee, 0xf7, 0x84, 0xaa, 0x72, 0xac, 0x35, 0x4d, 0x6a, 0x2a,
0x96, 0x1a, 0xd2, 0x71, 0x5a, 0x15, 0x49, 0x74, 0x4b, 0x9f, 0xd0, 0x5e, 0x04, 0x18, 0xa4, 0xec,
0xc2, 0xe0, 0x41, 0x6e, 0x0f, 0x51, 0xcb, 0xcc, 0x24, 0x91, 0xaf, 0x50, 0xa1, 0xf4, 0x70, 0x39,
0x99, 0x7c, 0x3a, 0x85, 0x23, 0xb8, 0xb4, 0x7a, 0xfc, 0x02, 0x36, 0x5b, 0x25, 0x55, 0x97, 0x31,
0x2d, 0x5d, 0xfa, 0x98, 0xe3, 0x8a, 0x92, 0xae, 0x05, 0xdf, 0x29, 0x10, 0x67, 0x6c, 0xba, 0xc9,
0xd3, 0x00, 0xe6, 0xcf, 0xe1, 0x9e, 0xa8, 0x2c, 0x63, 0x16, 0x01, 0x3f, 0x58, 0xe2, 0x89, 0xa9,
0x0d, 0x38, 0x34, 0x1b, 0xab, 0x33, 0xff, 0xb0, 0xbb, 0x48, 0x0c, 0x5f, 0xb9, 0xb1, 0xcd, 0x2e,
0xc5, 0xf3, 0xdb, 0x47, 0xe5, 0xa5, 0x9c, 0x77, 0x0a, 0xa6, 0x20, 0x68, 0xfe, 0x7f, 0xc1, 0xad,
}
func expandKey(key []byte, t1 int) [64]uint16 {
l := make([]byte, 128)
copy(l, key)
var t = len(key)
var t8 = (t1 + 7) / 8
var tm = byte(255 % uint(1<<(8+uint(t1)-8*uint(t8))))
for i := len(key); i < 128; i++ {
l[i] = piTable[l[i-1]+l[uint8(i-t)]]
}
l[128-t8] = piTable[l[128-t8]&tm]
for i := 127 - t8; i >= 0; i-- {
l[i] = piTable[l[i+1]^l[i+t8]]
}
var k [64]uint16
for i := range k {
k[i] = uint16(l[2*i]) + uint16(l[2*i+1])*256
}
return k
}
func rotl16(x uint16, b uint) uint16 {
return (x >> (16 - b)) | (x << b)
}
func (c *rc2Cipher) Encrypt(dst, src []byte) {
r0 := binary.LittleEndian.Uint16(src[0:])
r1 := binary.LittleEndian.Uint16(src[2:])
r2 := binary.LittleEndian.Uint16(src[4:])
r3 := binary.LittleEndian.Uint16(src[6:])
var j int
// These three mix blocks have not been extracted to a common function for to performance reasons.
for j <= 16 {
// mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1)
j++
// mix r1
r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2)
r1 = rotl16(r1, 2)
j++
// mix r2
r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3)
r2 = rotl16(r2, 3)
j++
// mix r3
r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0)
r3 = rotl16(r3, 5)
j++
}
r0 = r0 + c.k[r3&63]
r1 = r1 + c.k[r0&63]
r2 = r2 + c.k[r1&63]
r3 = r3 + c.k[r2&63]
for j <= 40 {
// mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1)
j++
// mix r1
r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2)
r1 = rotl16(r1, 2)
j++
// mix r2
r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3)
r2 = rotl16(r2, 3)
j++
// mix r3
r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0)
r3 = rotl16(r3, 5)
j++
}
r0 = r0 + c.k[r3&63]
r1 = r1 + c.k[r0&63]
r2 = r2 + c.k[r1&63]
r3 = r3 + c.k[r2&63]
for j <= 60 {
// mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1)
j++
// mix r1
r1 = r1 + c.k[j] + (r0 & r3) + ((^r0) & r2)
r1 = rotl16(r1, 2)
j++
// mix r2
r2 = r2 + c.k[j] + (r1 & r0) + ((^r1) & r3)
r2 = rotl16(r2, 3)
j++
// mix r3
r3 = r3 + c.k[j] + (r2 & r1) + ((^r2) & r0)
r3 = rotl16(r3, 5)
j++
}
binary.LittleEndian.PutUint16(dst[0:], r0)
binary.LittleEndian.PutUint16(dst[2:], r1)
binary.LittleEndian.PutUint16(dst[4:], r2)
binary.LittleEndian.PutUint16(dst[6:], r3)
}
func (c *rc2Cipher) Decrypt(dst, src []byte) {
r0 := binary.LittleEndian.Uint16(src[0:])
r1 := binary.LittleEndian.Uint16(src[2:])
r2 := binary.LittleEndian.Uint16(src[4:])
r3 := binary.LittleEndian.Uint16(src[6:])
j := 63
for j >= 44 {
// unmix r3
r3 = rotl16(r3, 16-5)
r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0)
j--
// unmix r2
r2 = rotl16(r2, 16-3)
r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3)
j--
// unmix r1
r1 = rotl16(r1, 16-2)
r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2)
j--
// unmix r0
r0 = rotl16(r0, 16-1)
r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1)
j--
}
r3 = r3 - c.k[r2&63]
r2 = r2 - c.k[r1&63]
r1 = r1 - c.k[r0&63]
r0 = r0 - c.k[r3&63]
for j >= 20 {
// unmix r3
r3 = rotl16(r3, 16-5)
r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0)
j--
// unmix r2
r2 = rotl16(r2, 16-3)
r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3)
j--
// unmix r1
r1 = rotl16(r1, 16-2)
r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2)
j--
// unmix r0
r0 = rotl16(r0, 16-1)
r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1)
j--
}
r3 = r3 - c.k[r2&63]
r2 = r2 - c.k[r1&63]
r1 = r1 - c.k[r0&63]
r0 = r0 - c.k[r3&63]
for j >= 0 {
// unmix r3
r3 = rotl16(r3, 16-5)
r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0)
j--
// unmix r2
r2 = rotl16(r2, 16-3)
r2 = r2 - c.k[j] - (r1 & r0) - ((^r1) & r3)
j--
// unmix r1
r1 = rotl16(r1, 16-2)
r1 = r1 - c.k[j] - (r0 & r3) - ((^r0) & r2)
j--
// unmix r0
r0 = rotl16(r0, 16-1)
r0 = r0 - c.k[j] - (r3 & r2) - ((^r3) & r1)
j--
}
binary.LittleEndian.PutUint16(dst[0:], r0)
binary.LittleEndian.PutUint16(dst[2:], r1)
binary.LittleEndian.PutUint16(dst[4:], r2)
binary.LittleEndian.PutUint16(dst[6:], r3)
}

View File

@ -0,0 +1,105 @@
package rc2
import (
"bytes"
"encoding/hex"
"testing"
)
func TestEncryptDecrypt(t *testing.T) {
var tests = []struct {
key string
plain string
cipher string
t1 int
}{
{
"0000000000000000",
"0000000000000000",
"ebb773f993278eff",
63,
},
{
"ffffffffffffffff",
"ffffffffffffffff",
"278b27e42e2f0d49",
64,
},
{
"3000000000000000",
"1000000000000001",
"30649edf9be7d2c2",
64,
},
{
"88",
"0000000000000000",
"61a8a244adacccf0",
64,
},
{
"88bca90e90875a",
"0000000000000000",
"6ccf4308974c267f",
64,
},
{
"88bca90e90875a7f0f79c384627bafb2",
"0000000000000000",
"1a807d272bbe5db1",
64,
},
{
"88bca90e90875a7f0f79c384627bafb2",
"0000000000000000",
"2269552ab0f85ca6",
128,
},
{
"88bca90e90875a7f0f79c384627bafb216f80a6f85920584c42fceb0be255daf1e",
"0000000000000000",
"5b78d3a43dfff1f1",
129,
},
}
for _, tt := range tests {
k, _ := hex.DecodeString(tt.key)
p, _ := hex.DecodeString(tt.plain)
c, _ := hex.DecodeString(tt.cipher)
b, _ := New(k, tt.t1)
var dst [8]byte
b.Encrypt(dst[:], p)
if !bytes.Equal(dst[:], c) {
t.Errorf("encrypt failed: got % 2x wanted % 2x\n", dst, c)
}
b.Decrypt(dst[:], c)
if !bytes.Equal(dst[:], p) {
t.Errorf("decrypt failed: got % 2x wanted % 2x\n", dst, p)
}
}
}
func BenchmarkEncrypt(b *testing.B) {
r, _ := New([]byte{0, 0, 0, 0, 0, 0, 0, 0}, 64)
b.ResetTimer()
var src [8]byte
for i := 0; i < b.N; i++ {
r.Encrypt(src[:], src[:])
}
}
func BenchmarkDecrypt(b *testing.B) {
r, _ := New([]byte{0, 0, 0, 0, 0, 0, 0, 0}, 64)
b.ResetTimer()
var src [8]byte
for i := 0; i < b.N; i++ {
r.Decrypt(src[:], src[:])
}
}

View File

@ -0,0 +1,67 @@
package pkcs12
import (
"crypto/x509/pkix"
"encoding/asn1"
"errors"
)
//see https://tools.ietf.org/html/rfc7292#appendix-D
var (
oidPkcs8ShroudedKeyBagType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 10, 1, 2}
oidCertBagType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 12, 10, 1, 3}
oidCertTypeX509Certificate = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 22, 1}
)
type certBag struct {
Id asn1.ObjectIdentifier
Data []byte `asn1:"tag:0,explicit"`
}
func getAlgorithmParams(salt []byte, iterations int) (asn1.RawValue, error) {
params := pbeParams{
Salt: salt,
Iterations: iterations,
}
return convertToRawVal(params)
}
func encodePkcs8ShroudedKeyBag(privateKey interface{}, password []byte) (bytes []byte, err error) {
privateKeyBytes, err := marshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, errors.New("pkcs12: error encoding PKCS#8 private key: " + err.Error())
}
salt, err := makeSalt(pbeSaltSizeBytes)
if err != nil {
return nil, errors.New("pkcs12: error creating PKCS#8 salt: " + err.Error())
}
pkData, err := pbEncrypt(privateKeyBytes, salt, password, pbeIterationCount)
if err != nil {
return nil, errors.New("pkcs12: error encoding PKCS#8 shrouded key bag when encrypting cert bag: " + err.Error())
}
params, err := getAlgorithmParams(salt, pbeIterationCount)
if err != nil {
return nil, errors.New("pkcs12: error encoding PKCS#8 shrouded key bag algorithm's parameters: " + err.Error())
}
pkinfo := encryptedPrivateKeyInfo{
AlgorithmIdentifier: pkix.AlgorithmIdentifier{
Algorithm: oidPbeWithSHAAnd3KeyTripleDESCBC,
Parameters: params,
},
EncryptedData: pkData,
}
bytes, err = asn1.Marshal(pkinfo)
if err != nil {
return nil, errors.New("pkcs12: error encoding PKCS#8 shrouded key bag: " + err.Error())
}
return bytes, err
}

View File

@ -0,0 +1,102 @@
package pkcs12
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/asn1"
"fmt"
"testing"
)
func decodePkcs8ShroudedKeyBag(asn1Data, password []byte) (privateKey interface{}, err error) {
pkinfo := new(encryptedPrivateKeyInfo)
if _, err = asn1.Unmarshal(asn1Data, pkinfo); err != nil {
err = fmt.Errorf("error decoding PKCS8 shrouded key bag: %v", err)
return nil, err
}
pkData, err := pbDecrypt(pkinfo, password)
if err != nil {
err = fmt.Errorf("error decrypting PKCS8 shrouded key bag: %v", err)
return
}
rv := new(asn1.RawValue)
if _, err = asn1.Unmarshal(pkData, rv); err != nil {
err = fmt.Errorf("could not decode decrypted private key data")
}
if privateKey, err = x509.ParsePKCS8PrivateKey(pkData); err != nil {
err = fmt.Errorf("error parsing PKCS8 private key: %v", err)
return nil, err
}
return
}
// Assert the default algorithm parameters are in the correct order,
// and default to the correct value. Defaults are based on OpenSSL.
// 1. IterationCount, defaults to 2,048 long.
// 2. Salt, is 8 bytes long.
func TestDefaultAlgorithmParametersPkcs8ShroudedKeyBag(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatalf("failed to generate a private key: %s", err)
}
password := []byte("sesame")
bytes, err := encodePkcs8ShroudedKeyBag(privateKey, password)
if err != nil {
t.Fatalf("failed to encode PKCS#8 shrouded key bag: %s", err)
}
var pkinfo encryptedPrivateKeyInfo
rest, err := asn1.Unmarshal(bytes, &pkinfo)
if err != nil {
t.Fatalf("failed to unmarshal encryptedPrivateKeyInfo %s", err)
}
if len(rest) != 0 {
t.Fatalf("unexpected trailing bytes of len=%d, bytes=%x", len(rest), rest)
}
var params pbeParams
rest, err = asn1.Unmarshal(pkinfo.GetAlgorithm().Parameters.FullBytes, &params)
if err != nil {
t.Fatalf("failed to unmarshal encryptedPrivateKeyInfo %s", err)
}
if len(rest) != 0 {
t.Fatalf("unexpected trailing bytes of len=%d, bytes=%x", len(rest), rest)
}
if params.Iterations != pbeIterationCount {
t.Errorf("expected iteration count to be %d, but actual=%d", pbeIterationCount, params.Iterations)
}
if len(params.Salt) != pbeSaltSizeBytes {
t.Errorf("expected the number of salt bytes to be %d, but actual=%d", pbeSaltSizeBytes, len(params.Salt))
}
}
func TestRoundTripPkcs8ShroudedKeyBag(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatalf("failed to generate a private key: %s", err)
}
password := []byte("sesame")
bytes, err := encodePkcs8ShroudedKeyBag(privateKey, password)
if err != nil {
t.Fatalf("failed to encode PKCS#8 shrouded key bag: %s", err)
}
key, err := decodePkcs8ShroudedKeyBag(bytes, password)
if err != nil {
t.Fatalf("failed to decode PKCS#8 shrouded key bag: %s", err)
}
actualPrivateKey := key.(*rsa.PrivateKey)
if actualPrivateKey.D.Cmp(privateKey.D) != 0 {
t.Fatalf("failed to round-trip rsa.PrivateKey.D")
}
}