Merge remote-tracking branch 'upstream/master' into packer-builder-profitbricks

This commit is contained in:
jasminSPC 2016-09-13 12:06:45 +02:00
commit 325401eaf0
204 changed files with 29333 additions and 4118 deletions

View File

@ -5,6 +5,7 @@ BACKWARDS INCOMPATIBILITIES:
* VNC and VRDP-like features in VirtualBox, VMware, and QEMU now configurable * VNC and VRDP-like features in VirtualBox, VMware, and QEMU now configurable
but bind to 127.0.0.1 by default to improve security. See the relevant but bind to 127.0.0.1 by default to improve security. See the relevant
builder docs for more info. builder docs for more info.
* Docker builder requires Docker > 1.3
FEATURES: FEATURES:
@ -12,12 +13,15 @@ FEATURES:
IMPROVEMENTS: IMPROVEMENTS:
* core: Test floppy disk files actually exist [GH-3756]
* builder/amazon: Added `disable_stop_instance` option to prevent automatic * builder/amazon: Added `disable_stop_instance` option to prevent automatic
shutdown when the build is complete [GH-3352] shutdown when the build is complete [GH-3352]
* builder/amazon: Added `skip_region_validation` option to allow newer or * builder/amazon: Added `skip_region_validation` option to allow newer or
custom AWS regions [GH-3598] custom AWS regions [GH-3598]
* builder/amazon: Added `shutdown_behavior` option to support `stop` or * builder/amazon: Added `shutdown_behavior` option to support `stop` or
`terminate` at the end of the build [GH-3556] `terminate` at the end of the build [GH-3556]
* builder/amazon: Support building from scratch with amazon-chroot builder.
[GH-3855]
* builder/azure: Now pre-validates `capture_container_name` and * builder/azure: Now pre-validates `capture_container_name` and
`capture_name_prefix` [GH-3537] `capture_name_prefix` [GH-3537]
* builder/azure: Support for custom images [GH-3575] * builder/azure: Support for custom images [GH-3575]
@ -25,8 +29,12 @@ IMPROVEMENTS:
* builder/azure: Made `tenant_id` optional [GH-3643] * builder/azure: Made `tenant_id` optional [GH-3643]
* builder/digitalocean: Use `state_timeout` for unlock and off transitions. * builder/digitalocean: Use `state_timeout` for unlock and off transitions.
[GH-3444] [GH-3444]
* builder/digitalocean: Fixes timeout waiting for snapshot [GH-3868]
* builder/docker: Improved support for Docker pull from Amazon ECR. [GH-3856]
* builder/google: Added support for `image_family` [GH-3503] * builder/google: Added support for `image_family` [GH-3503]
* builder/google: Use gcloud application default credentials. [GH-3655] * builder/google: Use gcloud application default credentials. [GH-3655]
* builder/google: Signal that startup script fished via metadata. [GH-3873]
* builder/google: Add image license metadata. [GH-3873]
* builder/null: Can now be used with WinRM [GH-2525] * builder/null: Can now be used with WinRM [GH-2525]
* builder/parallels: Now pauses between `boot_command` entries when running * builder/parallels: Now pauses between `boot_command` entries when running
with `-debug` [GH-3547] with `-debug` [GH-3547]
@ -45,11 +53,15 @@ IMPROVEMENTS:
* builder/qemu: Now pauses between `boot_command` entries when running with * builder/qemu: Now pauses between `boot_command` entries when running with
`-debug` [GH-3547] `-debug` [GH-3547]
* provisioner/ansible: Improved logging and error handling [GH-3477] * provisioner/ansible: Improved logging and error handling [GH-3477]
* provisioner/ansible: Support scp [GH-3861]
* provisioner/ansible-local: Support for ansible-galaxy [GH-3350] [GH-3836]
* provisioner/chef: Added `knife_command` option and added a correct default * provisioner/chef: Added `knife_command` option and added a correct default
value for Windows [GH-3622] value for Windows [GH-3622]
* provisioner/chef: Installs 64bit chef on Windows if available [GH-3848]
* provisioner/puppet: Added `execute_command` option [GH-3614] * provisioner/puppet: Added `execute_command` option [GH-3614]
* post-processor/compress: Added support for bgzf compression [GH-3501] * post-processor/compress: Added support for bgzf compression [GH-3501]
* post-processor/docker: Preserve tags when running docker push [GH-3631] * post-processor/docker: Preserve tags when running docker push [GH-3631]
* post-processor/docker: Improved support for Docker push to Amazon ECR [GH-3856]
* scripts: Added `help` target to Makefile [GH-3290] * scripts: Added `help` target to Makefile [GH-3290]
BUG FIXES: BUG FIXES:
@ -60,8 +72,10 @@ BUG FIXES:
is specified [GH-3568] is specified [GH-3568]
* builder/amazon: Use `temporary_key_pair_name` when specified. [GH-3739] * builder/amazon: Use `temporary_key_pair_name` when specified. [GH-3739]
* builder/amazon: Add 0.5 cents to discovered spot price. [GH-3662] * builder/amazon: Add 0.5 cents to discovered spot price. [GH-3662]
* builder/amazon: Fix packer crash when waiting for SSH. [GH-3865]
* builder/azure: check for empty resource group [GH-3606] * builder/azure: check for empty resource group [GH-3606]
* builder/azure: fix token validity test [GH-3609] * builder/azure: fix token validity test [GH-3609]
* builder/docker: fix docker builder with ansible provisioner. [GH-3476]
* builder/virtualbox: Respect `ssh_host` [GH-3617] * builder/virtualbox: Respect `ssh_host` [GH-3617]
* builder/vmware: Re-introduce case sensitive VMX keys [GH-2707] * builder/vmware: Re-introduce case sensitive VMX keys [GH-2707]
* builder/vmware: Don't check for poweron errors on ESXi [GH-3195] * builder/vmware: Don't check for poweron errors on ESXi [GH-3195]

155
Godeps/Godeps.json generated
View File

@ -1,5 +1,5 @@
{ {
"ImportPath": "github.com/StackPointCloud/packer", "ImportPath": "github.com/mitchellh/packer",
"GoVersion": "go1.6", "GoVersion": "go1.6",
"GodepVersion": "v74", "GodepVersion": "v74",
"Deps": [ "Deps": [
@ -80,138 +80,163 @@
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws", "ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/awserr", "ImportPath": "github.com/aws/aws-sdk-go/aws/awserr",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/awsutil", "ImportPath": "github.com/aws/aws-sdk-go/aws/awsutil",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/client", "ImportPath": "github.com/aws/aws-sdk-go/aws/client",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/client/metadata", "ImportPath": "github.com/aws/aws-sdk-go/aws/client/metadata",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/corehandlers", "ImportPath": "github.com/aws/aws-sdk-go/aws/corehandlers",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials", "ImportPath": "github.com/aws/aws-sdk-go/aws/credentials",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds", "ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/endpointcreds",
"Comment": "v1.4.6",
"Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/stscreds",
"Comment": "v1.4.6",
"Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/defaults", "ImportPath": "github.com/aws/aws-sdk-go/aws/defaults",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/ec2metadata", "ImportPath": "github.com/aws/aws-sdk-go/aws/ec2metadata",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/request", "ImportPath": "github.com/aws/aws-sdk-go/aws/request",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws/session", "ImportPath": "github.com/aws/aws-sdk-go/aws/session",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/signer/v4",
"Comment": "v1.4.6",
"Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/endpoints", "ImportPath": "github.com/aws/aws-sdk-go/private/endpoints",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/ec2query", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/ec2query",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil",
"Comment": "v1.4.6",
"Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/jsonrpc",
"Comment": "v1.4.6",
"Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query/queryutil", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query/queryutil",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/restxml", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/restxml",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4",
"Comment": "v1.1.2",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/waiter", "ImportPath": "github.com/aws/aws-sdk-go/private/waiter",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/ec2", "ImportPath": "github.com/aws/aws-sdk-go/service/ec2",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/ecr",
"Comment": "v1.4.6",
"Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/s3", "ImportPath": "github.com/aws/aws-sdk-go/service/s3",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/s3/s3iface", "ImportPath": "github.com/aws/aws-sdk-go/service/s3/s3iface",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/s3/s3manager", "ImportPath": "github.com/aws/aws-sdk-go/service/s3/s3manager",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/sts", "ImportPath": "github.com/aws/aws-sdk-go/service/sts",
"Comment": "v1.1.2", "Comment": "v1.4.6",
"Rev": "8041be5461786460d86b4358305fbdf32d37cfb2" "Rev": "6ac30507cca29249f4d49af45a8efc98b84088ee"
}, },
{ {
"ImportPath": "github.com/bgentry/speakeasy", "ImportPath": "github.com/bgentry/speakeasy",
@ -221,6 +246,10 @@
"ImportPath": "github.com/biogo/hts/bgzf", "ImportPath": "github.com/biogo/hts/bgzf",
"Rev": "50da7d4131a3b5c9d063932461cab4d1fafb20b0" "Rev": "50da7d4131a3b5c9d063932461cab4d1fafb20b0"
}, },
{
"ImportPath": "github.com/davecgh/go-spew/spew",
"Rev": "5215b55f46b2b919f50a1df0eaa5886afe4e3b3d"
},
{ {
"ImportPath": "github.com/dgrijalva/jwt-go", "ImportPath": "github.com/dgrijalva/jwt-go",
"Comment": "v3.0.0", "Comment": "v3.0.0",
@ -427,6 +456,11 @@
"ImportPath": "github.com/pkg/sftp", "ImportPath": "github.com/pkg/sftp",
"Rev": "e84cc8c755ca39b7b64f510fe1fffc1b51f210a5" "Rev": "e84cc8c755ca39b7b64f510fe1fffc1b51f210a5"
}, },
{
"ImportPath": "github.com/pmezard/go-difflib/difflib",
"Comment": "v1.0.0",
"Rev": "792786c7400a136282c1664665ae0a8db921c6c2"
},
{ {
"ImportPath": "github.com/rackspace/gophercloud", "ImportPath": "github.com/rackspace/gophercloud",
"Comment": "v1.0.0-810-g53d1dc4", "Comment": "v1.0.0-810-g53d1dc4",
@ -516,6 +550,11 @@
"ImportPath": "github.com/satori/go.uuid", "ImportPath": "github.com/satori/go.uuid",
"Rev": "d41af8bb6a7704f00bc3b7cba9355ae6a5a80048" "Rev": "d41af8bb6a7704f00bc3b7cba9355ae6a5a80048"
}, },
{
"ImportPath": "github.com/stretchr/testify/assert",
"Comment": "v1.1.3-19-gd77da35",
"Rev": "d77da356e56a7428ad25149ca77381849a6a5232"
},
{ {
"ImportPath": "github.com/tent/http-link-go", "ImportPath": "github.com/tent/http-link-go",
"Rev": "ac974c61c2f990f4115b119354b5e0b47550e888" "Rev": "ac974c61c2f990f4115b119354b5e0b47550e888"

View File

@ -25,19 +25,24 @@ const BuilderId = "mitchellh.amazon.chroot"
// Config is the configuration that is chained through the steps and // Config is the configuration that is chained through the steps and
// settable from the template. // settable from the template.
type Config struct { type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
awscommon.AccessConfig `mapstructure:",squash"` awscommon.AMIBlockDevices `mapstructure:",squash"`
awscommon.AMIConfig `mapstructure:",squash"` awscommon.AMIConfig `mapstructure:",squash"`
awscommon.AccessConfig `mapstructure:",squash"`
ChrootMounts [][]string `mapstructure:"chroot_mounts"` ChrootMounts [][]string `mapstructure:"chroot_mounts"`
CommandWrapper string `mapstructure:"command_wrapper"` CommandWrapper string `mapstructure:"command_wrapper"`
CopyFiles []string `mapstructure:"copy_files"` CopyFiles []string `mapstructure:"copy_files"`
DevicePath string `mapstructure:"device_path"` DevicePath string `mapstructure:"device_path"`
MountPath string `mapstructure:"mount_path"` FromScratch bool `mapstructure:"from_scratch"`
SourceAmi string `mapstructure:"source_ami"` MountOptions []string `mapstructure:"mount_options"`
RootVolumeSize int64 `mapstructure:"root_volume_size"` MountPartition int `mapstructure:"mount_partition"`
MountOptions []string `mapstructure:"mount_options"` MountPath string `mapstructure:"mount_path"`
MountPartition int `mapstructure:"mount_partition"` PostMountCommands []string `mapstructure:"post_mount_commands"`
PreMountCommands []string `mapstructure:"pre_mount_commands"`
RootDeviceName string `mapstructure:"root_device_name"`
RootVolumeSize int64 `mapstructure:"root_volume_size"`
SourceAmi string `mapstructure:"source_ami"`
ctx interpolate.Context ctx interpolate.Context
} }
@ -59,6 +64,8 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
InterpolateFilter: &interpolate.RenderFilter{ InterpolateFilter: &interpolate.RenderFilter{
Exclude: []string{ Exclude: []string{
"command_wrapper", "command_wrapper",
"post_mount_commands",
"pre_mount_commands",
"mount_path", "mount_path",
}, },
}, },
@ -86,7 +93,7 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
} }
} }
if len(b.config.CopyFiles) == 0 { if len(b.config.CopyFiles) == 0 && !b.config.FromScratch {
b.config.CopyFiles = []string{"/etc/resolv.conf"} b.config.CopyFiles = []string{"/etc/resolv.conf"}
} }
@ -102,8 +109,10 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
b.config.MountPartition = 1 b.config.MountPartition = 1
} }
// Accumulate any errors // Accumulate any errors or warnings
var errs *packer.MultiError var errs *packer.MultiError
var warns []string
errs = packer.MultiErrorAppend(errs, b.config.AccessConfig.Prepare(&b.config.ctx)...) errs = packer.MultiErrorAppend(errs, b.config.AccessConfig.Prepare(&b.config.ctx)...)
errs = packer.MultiErrorAppend(errs, b.config.AMIConfig.Prepare(&b.config.ctx)...) errs = packer.MultiErrorAppend(errs, b.config.AMIConfig.Prepare(&b.config.ctx)...)
@ -115,16 +124,49 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
} }
} }
if b.config.SourceAmi == "" { if b.config.FromScratch {
errs = packer.MultiErrorAppend(errs, errors.New("source_ami is required.")) if b.config.SourceAmi != "" {
warns = append(warns, "source_ami is unused when from_scratch is true")
}
if b.config.RootVolumeSize == 0 {
errs = packer.MultiErrorAppend(
errs, errors.New("root_volume_size is required with from_scratch."))
}
if len(b.config.PreMountCommands) == 0 {
errs = packer.MultiErrorAppend(
errs, errors.New("pre_mount_commands is required with from_scratch."))
}
if b.config.AMIVirtType == "" {
errs = packer.MultiErrorAppend(
errs, errors.New("ami_virtualization_type is required with from_scratch."))
}
if b.config.RootDeviceName == "" {
errs = packer.MultiErrorAppend(
errs, errors.New("root_device_name is required with from_scratch."))
}
if len(b.config.AMIMappings) == 0 {
errs = packer.MultiErrorAppend(
errs, errors.New("ami_block_device_mappings is required with from_scratch."))
}
} else {
if b.config.SourceAmi == "" {
errs = packer.MultiErrorAppend(
errs, errors.New("source_ami is required."))
}
if len(b.config.AMIMappings) != 0 {
warns = append(warns, "ami_block_device_mappings are unused when from_scratch is false")
}
if b.config.RootDeviceName != "" {
warns = append(warns, "root_device_name is unused when from_scratch is false")
}
} }
if errs != nil && len(errs.Errors) > 0 { if errs != nil && len(errs.Errors) > 0 {
return nil, errs return warns, errs
} }
log.Println(common.ScrubConfig(b.config, b.config.AccessKey, b.config.SecretKey)) log.Println(common.ScrubConfig(b.config, b.config.AccessKey, b.config.SecretKey))
return nil, nil return warns, nil
} }
func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packer.Artifact, error) { func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packer.Artifact, error) {
@ -161,11 +203,19 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe
ForceDeregister: b.config.AMIForceDeregister, ForceDeregister: b.config.AMIForceDeregister,
}, },
&StepInstanceInfo{}, &StepInstanceInfo{},
&awscommon.StepSourceAMIInfo{ }
SourceAmi: b.config.SourceAmi,
EnhancedNetworking: b.config.AMIEnhancedNetworking, if !b.config.FromScratch {
}, steps = append(steps,
&StepCheckRootDevice{}, &awscommon.StepSourceAMIInfo{
SourceAmi: b.config.SourceAmi,
EnhancedNetworking: b.config.AMIEnhancedNetworking,
},
&StepCheckRootDevice{},
)
}
steps = append(steps,
&StepFlock{}, &StepFlock{},
&StepPrepareDevice{}, &StepPrepareDevice{},
&StepCreateVolume{ &StepCreateVolume{
@ -173,10 +223,16 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe
}, },
&StepAttachVolume{}, &StepAttachVolume{},
&StepEarlyUnflock{}, &StepEarlyUnflock{},
&StepPreMountCommands{
Commands: b.config.PreMountCommands,
},
&StepMountDevice{ &StepMountDevice{
MountOptions: b.config.MountOptions, MountOptions: b.config.MountOptions,
MountPartition: b.config.MountPartition, MountPartition: b.config.MountPartition,
}, },
&StepPostMountCommands{
Commands: b.config.PostMountCommands,
},
&StepMountExtra{}, &StepMountExtra{},
&StepCopyFiles{}, &StepCopyFiles{},
&StepChrootProvision{}, &StepChrootProvision{},
@ -203,7 +259,7 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe
&awscommon.StepCreateTags{ &awscommon.StepCreateTags{
Tags: b.config.AMITags, Tags: b.config.AMITags,
}, },
} )
// Run! // Run!
if b.config.PackerDebug { if b.config.PackerDebug {

View File

@ -0,0 +1,37 @@
package chroot
import (
"fmt"
"github.com/mitchellh/packer/packer"
"github.com/mitchellh/packer/post-processor/shell-local"
"github.com/mitchellh/packer/template/interpolate"
)
func RunLocalCommands(commands []string, wrappedCommand CommandWrapper, ctx interpolate.Context, ui packer.Ui) error {
for _, rawCmd := range commands {
intCmd, err := interpolate.Render(rawCmd, &ctx)
if err != nil {
return fmt.Errorf("Error interpolating: %s", err)
}
command, err := wrappedCommand(intCmd)
if err != nil {
return fmt.Errorf("Error wrapping command: %s", err)
}
ui.Say(fmt.Sprintf("Executing command: %s", command))
comm := &shell_local.Communicator{}
cmd := &packer.RemoteCmd{Command: command}
if err := cmd.StartWithUi(comm, ui); err != nil {
return fmt.Errorf("Error executing command: %s", err)
}
if cmd.ExitStatus != 0 {
return fmt.Errorf(
"Received non-zero exit code %d from command: %s",
cmd.ExitStatus,
command)
}
}
return nil
}

View File

@ -22,40 +22,52 @@ type StepCreateVolume struct {
} }
func (s *StepCreateVolume) Run(state multistep.StateBag) multistep.StepAction { func (s *StepCreateVolume) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config)
ec2conn := state.Get("ec2").(*ec2.EC2) ec2conn := state.Get("ec2").(*ec2.EC2)
image := state.Get("source_image").(*ec2.Image)
instance := state.Get("instance").(*ec2.Instance) instance := state.Get("instance").(*ec2.Instance)
ui := state.Get("ui").(packer.Ui) ui := state.Get("ui").(packer.Ui)
// Determine the root device snapshot var createVolume *ec2.CreateVolumeInput
log.Printf("Searching for root device of the image (%s)", *image.RootDeviceName) if config.FromScratch {
var rootDevice *ec2.BlockDeviceMapping createVolume = &ec2.CreateVolumeInput{
for _, device := range image.BlockDeviceMappings { AvailabilityZone: instance.Placement.AvailabilityZone,
if *device.DeviceName == *image.RootDeviceName { Size: aws.Int64(s.RootVolumeSize),
rootDevice = device VolumeType: aws.String(ec2.VolumeTypeGp2),
break }
} else {
// Determine the root device snapshot
image := state.Get("source_image").(*ec2.Image)
log.Printf("Searching for root device of the image (%s)", *image.RootDeviceName)
var rootDevice *ec2.BlockDeviceMapping
for _, device := range image.BlockDeviceMappings {
if *device.DeviceName == *image.RootDeviceName {
rootDevice = device
break
}
}
if rootDevice == nil {
err := fmt.Errorf("Couldn't find root device!")
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
ui.Say("Creating the root volume...")
vs := *rootDevice.Ebs.VolumeSize
if s.RootVolumeSize > *rootDevice.Ebs.VolumeSize {
vs = s.RootVolumeSize
}
createVolume = &ec2.CreateVolumeInput{
AvailabilityZone: instance.Placement.AvailabilityZone,
Size: aws.Int64(vs),
SnapshotId: rootDevice.Ebs.SnapshotId,
VolumeType: rootDevice.Ebs.VolumeType,
Iops: rootDevice.Ebs.Iops,
} }
} }
if rootDevice == nil {
err := fmt.Errorf("Couldn't find root device!")
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
ui.Say("Creating the root volume...")
vs := *rootDevice.Ebs.VolumeSize
if s.RootVolumeSize > *rootDevice.Ebs.VolumeSize {
vs = s.RootVolumeSize
}
createVolume := &ec2.CreateVolumeInput{
AvailabilityZone: instance.Placement.AvailabilityZone,
Size: aws.Int64(vs),
SnapshotId: rootDevice.Ebs.SnapshotId,
VolumeType: rootDevice.Ebs.VolumeType,
Iops: rootDevice.Ebs.Iops,
}
log.Printf("Create args: %+v", createVolume) log.Printf("Create args: %+v", createVolume)
createVolumeResp, err := ec2conn.CreateVolume(createVolume) createVolumeResp, err := ec2conn.CreateVolume(createVolume)

View File

@ -33,10 +33,18 @@ type StepMountDevice struct {
func (s *StepMountDevice) Run(state multistep.StateBag) multistep.StepAction { func (s *StepMountDevice) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config) config := state.Get("config").(*Config)
ui := state.Get("ui").(packer.Ui) ui := state.Get("ui").(packer.Ui)
image := state.Get("source_image").(*ec2.Image)
device := state.Get("device").(string) device := state.Get("device").(string)
wrappedCommand := state.Get("wrappedCommand").(CommandWrapper) wrappedCommand := state.Get("wrappedCommand").(CommandWrapper)
var virtualizationType string
if config.FromScratch {
virtualizationType = config.AMIVirtType
} else {
image := state.Get("source_image").(*ec2.Image)
virtualizationType = *image.VirtualizationType
log.Printf("Source image virtualization type is: %s", virtualizationType)
}
ctx := config.ctx ctx := config.ctx
ctx.Data = &mountPathData{Device: filepath.Base(device)} ctx.Data = &mountPathData{Device: filepath.Base(device)}
mountPath, err := interpolate.Render(config.MountPath, &ctx) mountPath, err := interpolate.Render(config.MountPath, &ctx)
@ -65,9 +73,8 @@ func (s *StepMountDevice) Run(state multistep.StateBag) multistep.StepAction {
return multistep.ActionHalt return multistep.ActionHalt
} }
log.Printf("Source image virtualization type is: %s", *image.VirtualizationType)
deviceMount := device deviceMount := device
if *image.VirtualizationType == "hvm" { if virtualizationType == "hvm" {
deviceMount = fmt.Sprintf("%s%d", device, s.MountPartition) deviceMount = fmt.Sprintf("%s%d", device, s.MountPartition)
} }
state.Put("deviceMount", deviceMount) state.Put("deviceMount", deviceMount)

View File

@ -0,0 +1,45 @@
package chroot
import (
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type postMountCommandsData struct {
Device string
MountPath string
}
// StepPostMountCommands allows running arbitrary commands after mounting the
// device, but prior to the bind mount and copy steps.
type StepPostMountCommands struct {
Commands []string
}
func (s *StepPostMountCommands) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config)
device := state.Get("device").(string)
mountPath := state.Get("mount_path").(string)
ui := state.Get("ui").(packer.Ui)
wrappedCommand := state.Get("wrappedCommand").(CommandWrapper)
if len(s.Commands) == 0 {
return multistep.ActionContinue
}
ctx := config.ctx
ctx.Data = &postMountCommandsData{
Device: device,
MountPath: mountPath,
}
ui.Say("Running post-mount commands...")
if err := RunLocalCommands(s.Commands, wrappedCommand, ctx, ui); err != nil {
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
return multistep.ActionContinue
}
func (s *StepPostMountCommands) Cleanup(state multistep.StateBag) {}

View File

@ -0,0 +1,39 @@
package chroot
import (
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
)
type preMountCommandsData struct {
Device string
}
// StepPreMountCommands sets up the a new block device when building from scratch
type StepPreMountCommands struct {
Commands []string
}
func (s *StepPreMountCommands) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config)
device := state.Get("device").(string)
ui := state.Get("ui").(packer.Ui)
wrappedCommand := state.Get("wrappedCommand").(CommandWrapper)
if len(s.Commands) == 0 {
return multistep.ActionContinue
}
ctx := config.ctx
ctx.Data = &preMountCommandsData{Device: device}
ui.Say("Running device setup commands...")
if err := RunLocalCommands(s.Commands, wrappedCommand, ctx, ui); err != nil {
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
return multistep.ActionContinue
}
func (s *StepPreMountCommands) Cleanup(state multistep.StateBag) {}

View File

@ -18,22 +18,36 @@ type StepRegisterAMI struct {
func (s *StepRegisterAMI) Run(state multistep.StateBag) multistep.StepAction { func (s *StepRegisterAMI) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config) config := state.Get("config").(*Config)
ec2conn := state.Get("ec2").(*ec2.EC2) ec2conn := state.Get("ec2").(*ec2.EC2)
image := state.Get("source_image").(*ec2.Image)
snapshotId := state.Get("snapshot_id").(string) snapshotId := state.Get("snapshot_id").(string)
ui := state.Get("ui").(packer.Ui) ui := state.Get("ui").(packer.Ui)
ui.Say("Registering the AMI...") ui.Say("Registering the AMI...")
blockDevices := make([]*ec2.BlockDeviceMapping, len(image.BlockDeviceMappings))
for i, device := range image.BlockDeviceMappings { var (
registerOpts *ec2.RegisterImageInput
blockDevices []*ec2.BlockDeviceMapping
image *ec2.Image
rootDeviceName string
)
if config.FromScratch {
blockDevices = config.AMIBlockDevices.BuildAMIDevices()
rootDeviceName = config.RootDeviceName
} else {
image = state.Get("source_image").(*ec2.Image)
blockDevices = make([]*ec2.BlockDeviceMapping, len(image.BlockDeviceMappings))
rootDeviceName = *image.RootDeviceName
}
for i, device := range blockDevices {
newDevice := device newDevice := device
if *newDevice.DeviceName == *image.RootDeviceName { if *newDevice.DeviceName == rootDeviceName {
if newDevice.Ebs != nil { if newDevice.Ebs != nil {
newDevice.Ebs.SnapshotId = aws.String(snapshotId) newDevice.Ebs.SnapshotId = aws.String(snapshotId)
} else { } else {
newDevice.Ebs = &ec2.EbsBlockDevice{SnapshotId: aws.String(snapshotId)} newDevice.Ebs = &ec2.EbsBlockDevice{SnapshotId: aws.String(snapshotId)}
} }
if s.RootVolumeSize > *newDevice.Ebs.VolumeSize { if config.FromScratch || s.RootVolumeSize > *newDevice.Ebs.VolumeSize {
newDevice.Ebs.VolumeSize = aws.Int64(s.RootVolumeSize) newDevice.Ebs.VolumeSize = aws.Int64(s.RootVolumeSize)
} }
} }
@ -47,7 +61,17 @@ func (s *StepRegisterAMI) Run(state multistep.StateBag) multistep.StepAction {
blockDevices[i] = newDevice blockDevices[i] = newDevice
} }
registerOpts := buildRegisterOpts(config, image, blockDevices) if config.FromScratch {
registerOpts = &ec2.RegisterImageInput{
Name: &config.AMIName,
Architecture: aws.String(ec2.ArchitectureValuesX8664),
RootDeviceName: aws.String(rootDeviceName),
VirtualizationType: aws.String(config.AMIVirtType),
BlockDeviceMappings: blockDevices,
}
} else {
registerOpts = buildRegisterOpts(config, image, blockDevices)
}
// Set SriovNetSupport to "simple". See http://goo.gl/icuXh5 // Set SriovNetSupport to "simple". See http://goo.gl/icuXh5
if config.AMIEnhancedNetworking { if config.AMIEnhancedNetworking {
@ -105,6 +129,5 @@ func buildRegisterOpts(config *Config, image *ec2.Image, blockDevices []*ec2.Blo
registerOpts.KernelId = image.KernelId registerOpts.KernelId = image.KernelId
registerOpts.RamdiskId = image.RamdiskId registerOpts.RamdiskId = image.RamdiskId
} }
return registerOpts return registerOpts
} }

View File

@ -22,7 +22,15 @@ type BlockDevice struct {
} }
type BlockDevices struct { type BlockDevices struct {
AMIMappings []BlockDevice `mapstructure:"ami_block_device_mappings"` AMIBlockDevices `mapstructure:",squash"`
LaunchBlockDevices `mapstructure:",squash"`
}
type AMIBlockDevices struct {
AMIMappings []BlockDevice `mapstructure:"ami_block_device_mappings"`
}
type LaunchBlockDevices struct {
LaunchMappings []BlockDevice `mapstructure:"launch_block_device_mappings"` LaunchMappings []BlockDevice `mapstructure:"launch_block_device_mappings"`
} }
@ -77,10 +85,10 @@ func (b *BlockDevices) Prepare(ctx *interpolate.Context) []error {
return nil return nil
} }
func (b *BlockDevices) BuildAMIDevices() []*ec2.BlockDeviceMapping { func (b *AMIBlockDevices) BuildAMIDevices() []*ec2.BlockDeviceMapping {
return buildBlockDevices(b.AMIMappings) return buildBlockDevices(b.AMIMappings)
} }
func (b *BlockDevices) BuildLaunchDevices() []*ec2.BlockDeviceMapping { func (b *LaunchBlockDevices) BuildLaunchDevices() []*ec2.BlockDeviceMapping {
return buildBlockDevices(b.LaunchMappings) return buildBlockDevices(b.LaunchMappings)
} }

View File

@ -124,22 +124,26 @@ func TestBlockDevice(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
blockDevices := BlockDevices{ amiBlockDevices := AMIBlockDevices{
AMIMappings: []BlockDevice{*tc.Config}, AMIMappings: []BlockDevice{*tc.Config},
}
launchBlockDevices := LaunchBlockDevices{
LaunchMappings: []BlockDevice{*tc.Config}, LaunchMappings: []BlockDevice{*tc.Config},
} }
expected := []*ec2.BlockDeviceMapping{tc.Result} expected := []*ec2.BlockDeviceMapping{tc.Result}
got := blockDevices.BuildAMIDevices()
if !reflect.DeepEqual(expected, got) { amiResults := amiBlockDevices.BuildAMIDevices()
if !reflect.DeepEqual(expected, amiResults) {
t.Fatalf("Bad block device, \nexpected: %#v\n\ngot: %#v", t.Fatalf("Bad block device, \nexpected: %#v\n\ngot: %#v",
expected, got) expected, amiResults)
} }
if !reflect.DeepEqual(expected, blockDevices.BuildLaunchDevices()) { launchResults := launchBlockDevices.BuildLaunchDevices()
if !reflect.DeepEqual(expected, launchResults) {
t.Fatalf("Bad block device, \nexpected: %#v\n\ngot: %#v", t.Fatalf("Bad block device, \nexpected: %#v\n\ngot: %#v",
expected, expected, launchResults)
blockDevices.BuildLaunchDevices())
} }
} }
} }

View File

@ -10,17 +10,28 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type ec2Describer interface {
DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
}
var (
// modified in tests
sshHostSleepDuration = time.Second
)
// SSHHost returns a function that can be given to the SSH communicator // SSHHost returns a function that can be given to the SSH communicator
// for determining the SSH address based on the instance DNS name. // for determining the SSH address based on the instance DNS name.
func SSHHost(e *ec2.EC2, private bool) func(multistep.StateBag) (string, error) { func SSHHost(e ec2Describer, private bool) func(multistep.StateBag) (string, error) {
return func(state multistep.StateBag) (string, error) { return func(state multistep.StateBag) (string, error) {
for j := 0; j < 2; j++ { const tries = 2
// <= with current structure to check result of describing `tries` times
for j := 0; j <= tries; j++ {
var host string var host string
i := state.Get("instance").(*ec2.Instance) i := state.Get("instance").(*ec2.Instance)
if i.VpcId != nil && *i.VpcId != "" { if i.VpcId != nil && *i.VpcId != "" {
if i.PublicIpAddress != nil && *i.PublicIpAddress != "" && !private { if i.PublicIpAddress != nil && *i.PublicIpAddress != "" && !private {
host = *i.PublicIpAddress host = *i.PublicIpAddress
} else { } else if i.PrivateIpAddress != nil && *i.PrivateIpAddress != "" {
host = *i.PrivateIpAddress host = *i.PrivateIpAddress
} }
} else if i.PublicDnsName != nil && *i.PublicDnsName != "" { } else if i.PublicDnsName != nil && *i.PublicDnsName != "" {
@ -42,8 +53,8 @@ func SSHHost(e *ec2.EC2, private bool) func(multistep.StateBag) (string, error)
return "", fmt.Errorf("instance not found: %s", *i.InstanceId) return "", fmt.Errorf("instance not found: %s", *i.InstanceId)
} }
state.Put("instance", &r.Reservations[0].Instances[0]) state.Put("instance", r.Reservations[0].Instances[0])
time.Sleep(1 * time.Second) time.Sleep(sshHostSleepDuration)
} }
return "", errors.New("couldn't determine IP address for instance") return "", errors.New("couldn't determine IP address for instance")

View File

@ -0,0 +1,118 @@
package common
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/mitchellh/multistep"
)
const (
privateIP = "10.0.0.1"
publicIP = "192.168.1.1"
publicDNS = "public.dns.test"
)
func TestSSHHost(t *testing.T) {
origSshHostSleepDuration := sshHostSleepDuration
defer func() { sshHostSleepDuration = origSshHostSleepDuration }()
sshHostSleepDuration = 0
var cases = []struct {
allowTries int
vpcId string
private bool
ok bool
wantHost string
}{
{1, "", false, true, publicDNS},
{1, "", true, true, publicDNS},
{1, "vpc-id", false, true, publicIP},
{1, "vpc-id", true, true, privateIP},
{2, "", false, true, publicDNS},
{2, "", true, true, publicDNS},
{2, "vpc-id", false, true, publicIP},
{2, "vpc-id", true, true, privateIP},
{3, "", false, false, ""},
{3, "", true, false, ""},
{3, "vpc-id", false, false, ""},
{3, "vpc-id", true, false, ""},
}
for _, c := range cases {
testSSHHost(t, c.allowTries, c.vpcId, c.private, c.ok, c.wantHost)
}
}
func testSSHHost(t *testing.T, allowTries int, vpcId string, private, ok bool, wantHost string) {
t.Logf("allowTries=%d vpcId=%s private=%t ok=%t wantHost=%q", allowTries, vpcId, private, ok, wantHost)
e := &fakeEC2Describer{
allowTries: allowTries,
vpcId: vpcId,
privateIP: privateIP,
publicIP: publicIP,
publicDNS: publicDNS,
}
f := SSHHost(e, private)
st := &multistep.BasicStateBag{}
st.Put("instance", &ec2.Instance{
InstanceId: aws.String("instance-id"),
})
host, err := f(st)
if e.tries > allowTries {
t.Fatalf("got %d ec2 DescribeInstances tries, want %d", e.tries, allowTries)
}
switch {
case ok && err != nil:
t.Fatalf("expected no error, got %+v", err)
case !ok && err == nil:
t.Fatalf("expected error, got none and host %s", host)
}
if host != wantHost {
t.Fatalf("got host %s, want %s", host, wantHost)
}
}
type fakeEC2Describer struct {
allowTries int
tries int
vpcId string
privateIP, publicIP, publicDNS string
}
func (d *fakeEC2Describer) DescribeInstances(in *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
d.tries++
instance := &ec2.Instance{
InstanceId: aws.String("instance-id"),
}
if d.vpcId != "" {
instance.VpcId = aws.String(d.vpcId)
}
if d.tries >= d.allowTries {
instance.PublicIpAddress = aws.String(d.publicIP)
instance.PrivateIpAddress = aws.String(d.privateIP)
instance.PublicDnsName = aws.String(d.publicDNS)
}
out := &ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{
{
Instances: []*ec2.Instance{instance},
},
},
}
return out, nil
}

View File

@ -20,7 +20,7 @@ func (s *stepSnapshot) Run(state multistep.StateBag) multistep.StepAction {
dropletId := state.Get("droplet_id").(int) dropletId := state.Get("droplet_id").(int)
ui.Say(fmt.Sprintf("Creating snapshot: %v", c.SnapshotName)) ui.Say(fmt.Sprintf("Creating snapshot: %v", c.SnapshotName))
_, _, err := client.DropletActions.Snapshot(dropletId, c.SnapshotName) action, _, err := client.DropletActions.Snapshot(dropletId, c.SnapshotName)
if err != nil { if err != nil {
err := fmt.Errorf("Error creating snapshot: %s", err) err := fmt.Errorf("Error creating snapshot: %s", err)
state.Put("error", err) state.Put("error", err)
@ -28,6 +28,17 @@ func (s *stepSnapshot) Run(state multistep.StateBag) multistep.StepAction {
return multistep.ActionHalt return multistep.ActionHalt
} }
// With the pending state over, verify that we're in the active state
ui.Say("Waiting for snapshot to complete...")
if err := waitForActionState(godo.ActionCompleted, dropletId, action.ID,
client, 20*time.Minute); err != nil {
// If we get an error the first time, actually report it
err := fmt.Errorf("Error waiting for snapshot: %s", err)
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
// Wait for the droplet to become unlocked first. For snapshots // Wait for the droplet to become unlocked first. For snapshots
// this can end up taking quite a long time, so we hardcode this to // this can end up taking quite a long time, so we hardcode this to
// 20 minutes. // 20 minutes.
@ -39,16 +50,6 @@ func (s *stepSnapshot) Run(state multistep.StateBag) multistep.StepAction {
return multistep.ActionHalt return multistep.ActionHalt
} }
// With the pending state over, verify that we're in the active state
ui.Say("Waiting for snapshot to complete...")
err = waitForDropletState("active", dropletId, client, c.StateTimeout)
if err != nil {
err := fmt.Errorf("Error waiting for snapshot to complete: %s", err)
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
log.Printf("Looking up snapshot ID for snapshot: %s", c.SnapshotName) log.Printf("Looking up snapshot ID for snapshot: %s", c.SnapshotName)
images, _, err := client.Droplets.Snapshots(dropletId, nil) images, _, err := client.Droplets.Snapshots(dropletId, nil)
if err != nil { if err != nil {

View File

@ -57,7 +57,7 @@ func waitForDropletUnlocked(
} }
} }
// waitForState simply blocks until the droplet is in // waitForDropletState simply blocks until the droplet is in
// a state we expect, while eventually timing out. // a state we expect, while eventually timing out.
func waitForDropletState( func waitForDropletState(
desiredState string, dropletId int, desiredState string, dropletId int,
@ -106,3 +106,53 @@ func waitForDropletState(
return err return err
} }
} }
// waitForActionState simply blocks until the droplet action is in
// a state we expect, while eventually timing out.
func waitForActionState(
desiredState string, dropletId, actionId int,
client *godo.Client, timeout time.Duration) error {
done := make(chan struct{})
defer close(done)
result := make(chan error, 1)
go func() {
attempts := 0
for {
attempts += 1
log.Printf("Checking action status... (attempt: %d)", attempts)
action, _, err := client.DropletActions.Get(dropletId, actionId)
if err != nil {
result <- err
return
}
if action.Status == desiredState {
result <- nil
return
}
// Wait 3 seconds in between
time.Sleep(3 * time.Second)
// Verify we shouldn't exit
select {
case <-done:
// We finished, so just exit the goroutine
return
default:
// Keep going
}
}
}()
log.Printf("Waiting for up to %d seconds for action to become %s", timeout/time.Second, desiredState)
select {
case err := <-result:
return err
case <-time.After(timeout):
err := fmt.Errorf("Timeout while waiting to for action to become '%s'", desiredState)
return err
}
}

View File

@ -2,7 +2,6 @@ package docker
import ( import (
"archive/tar" "archive/tar"
"bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -10,12 +9,10 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strconv" "strings"
"sync" "sync"
"syscall" "syscall"
"time"
"github.com/ActiveState/tail"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
) )
@ -30,42 +27,35 @@ type Communicator struct {
} }
func (c *Communicator) Start(remote *packer.RemoteCmd) error { func (c *Communicator) Start(remote *packer.RemoteCmd) error {
// Create a temporary file to store the output. Because of a bug in var cmd *exec.Cmd
// Docker, sometimes all the output doesn't properly show up. This if c.Config.Pty {
// file will capture ALL of the output, and we'll read that. cmd = exec.Command("docker", "exec", "-i", "-t", c.ContainerId, "/bin/sh", "-c", fmt.Sprintf("(%s)", remote.Command))
// } else {
// https://github.com/dotcloud/docker/issues/2625 cmd = exec.Command("docker", "exec", "-i", c.ContainerId, "/bin/sh", "-c", fmt.Sprintf("(%s)", remote.Command))
outputFile, err := ioutil.TempFile(c.HostDir, "cmd") }
var (
stdin_w io.WriteCloser
err error
)
stdin_w, err = cmd.StdinPipe()
if err != nil { if err != nil {
return err return err
} }
outputFile.Close()
// This file will store the exit code of the command once it is complete. stderr_r, err := cmd.StderrPipe()
exitCodePath := outputFile.Name() + "-exit" if err != nil {
return err
var cmd *exec.Cmd
if c.canExec() {
if c.Config.Pty {
cmd = exec.Command("docker", "exec", "-i", "-t", c.ContainerId, "/bin/sh")
} else {
cmd = exec.Command("docker", "exec", "-i", c.ContainerId, "/bin/sh")
}
} else {
cmd = exec.Command("docker", "attach", c.ContainerId)
} }
stdin_w, err := cmd.StdinPipe() stdout_r, err := cmd.StdoutPipe()
if err != nil { if err != nil {
// We have to do some cleanup since run was never called
os.Remove(outputFile.Name())
os.Remove(exitCodePath)
return err return err
} }
// Run the actual command in a goroutine so that Start doesn't block // Run the actual command in a goroutine so that Start doesn't block
go c.run(cmd, remote, stdin_w, outputFile, exitCodePath) go c.run(cmd, remote, stdin_w, stdout_r, stderr_r)
return nil return nil
} }
@ -237,105 +227,51 @@ func (c *Communicator) DownloadDir(src string, dst string, exclude []string) err
return fmt.Errorf("DownloadDir is not implemented for docker") return fmt.Errorf("DownloadDir is not implemented for docker")
} }
// canExec tells us whether `docker exec` is supported
func (c *Communicator) canExec() bool {
execConstraint, err := version.NewConstraint(">= 1.4.0")
if err != nil {
panic(err)
}
return execConstraint.Check(c.Version)
}
// Runs the given command and blocks until completion // Runs the given command and blocks until completion
func (c *Communicator) run(cmd *exec.Cmd, remote *packer.RemoteCmd, stdin_w io.WriteCloser, outputFile *os.File, exitCodePath string) { func (c *Communicator) run(cmd *exec.Cmd, remote *packer.RemoteCmd, stdin io.WriteCloser, stdout, stderr io.ReadCloser) {
// For Docker, remote communication must be serialized since it // For Docker, remote communication must be serialized since it
// only supports single execution. // only supports single execution.
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
// Clean up after ourselves by removing our temporary files wg := sync.WaitGroup{}
defer os.Remove(outputFile.Name()) repeat := func(w io.Writer, r io.ReadCloser) {
defer os.Remove(exitCodePath) io.Copy(w, r)
r.Close()
// Tail the output file and send the data to the stdout listener wg.Done()
tail, err := tail.TailFile(outputFile.Name(), tail.Config{
Poll: true,
ReOpen: true,
Follow: true,
})
if err != nil {
log.Printf("Error tailing output file: %s", err)
remote.SetExited(254)
return
} }
defer tail.Stop()
// Modify the remote command so that all the output of the commands if remote.Stdout != nil {
// go to a single file and so that the exit code is redirected to wg.Add(1)
// a single file. This lets us determine both when the command go repeat(remote.Stdout, stdout)
// is truly complete (because the file will have data), what the }
// exit status is (because Docker loses it because of the pty, not
// Docker's fault), and get the output (Docker bug). if remote.Stderr != nil {
remoteCmd := fmt.Sprintf("(%s) >%s 2>&1; echo $? >%s", wg.Add(1)
remote.Command, go repeat(remote.Stderr, stderr)
filepath.Join(c.ContainerDir, filepath.Base(outputFile.Name())), }
filepath.Join(c.ContainerDir, filepath.Base(exitCodePath)))
// Start the command // Start the command
log.Printf("Executing in container %s: %#v", c.ContainerId, remoteCmd) log.Printf("Executing %s:", strings.Join(cmd.Args, " "))
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
log.Printf("Error executing: %s", err) log.Printf("Error executing: %s", err)
remote.SetExited(254) remote.SetExited(254)
return return
} }
go func() {
defer stdin_w.Close()
// This sleep needs to be here because of the issue linked to below.
// Basically, without it, Docker will hang on reading stdin forever,
// and won't see what we write, for some reason.
//
// https://github.com/dotcloud/docker/issues/2628
time.Sleep(2 * time.Second)
stdin_w.Write([]byte(remoteCmd + "\n"))
}()
// Start a goroutine to read all the lines out of the logs. These channels
// allow us to stop the go-routine and wait for it to be stopped.
stopTailCh := make(chan struct{})
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
for {
select {
case <-tail.Dead():
return
case line := <-tail.Lines:
if remote.Stdout != nil {
remote.Stdout.Write([]byte(line.Text + "\n"))
} else {
log.Printf("Command stdout: %#v", line.Text)
}
case <-time.After(2 * time.Second):
// If we're done, then return. Otherwise, keep grabbing
// data. This gives us a chance to flush all the lines
// out of the tailed file.
select {
case <-stopTailCh:
return
default:
}
}
}
}()
var exitRaw []byte
var exitStatus int var exitStatus int
var exitStatusRaw int64
err = cmd.Wait() if remote.Stdin != nil {
go func() {
io.Copy(stdin, remote.Stdin)
// close stdin to support commands that wait for stdin to be closed before exiting.
stdin.Close()
}()
}
wg.Wait()
err := cmd.Wait()
if exitErr, ok := err.(*exec.ExitError); ok { if exitErr, ok := err.(*exec.ExitError); ok {
exitStatus = 1 exitStatus = 1
@ -344,45 +280,8 @@ func (c *Communicator) run(cmd *exec.Cmd, remote *packer.RemoteCmd, stdin_w io.W
if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
exitStatus = status.ExitStatus() exitStatus = status.ExitStatus()
} }
// Say that we ended, since if Docker itself failed, then
// the command must've not run, or so we assume
goto REMOTE_EXIT
} }
// Wait for the exit code to appear in our file...
log.Println("Waiting for exit code to appear for remote command...")
for {
fi, err := os.Stat(exitCodePath)
if err == nil && fi.Size() > 0 {
break
}
time.Sleep(1 * time.Second)
}
// Read the exit code
exitRaw, err = ioutil.ReadFile(exitCodePath)
if err != nil {
log.Printf("Error executing: %s", err)
exitStatus = 254
goto REMOTE_EXIT
}
exitStatusRaw, err = strconv.ParseInt(string(bytes.TrimSpace(exitRaw)), 10, 0)
if err != nil {
log.Printf("Error executing: %s", err)
exitStatus = 254
goto REMOTE_EXIT
}
exitStatus = int(exitStatusRaw)
log.Printf("Executed command exit status: %d", exitStatus)
REMOTE_EXIT:
// Wait for the tail to finish
close(stopTailCh)
<-doneCh
// Set the exit status which triggers waiters // Set the exit status which triggers waiters
remote.SetExited(exitStatus) remote.SetExited(exitStatus)
} }

View File

@ -35,11 +35,13 @@ type Config struct {
// This is used to login to dockerhub to pull a private base container. For // This is used to login to dockerhub to pull a private base container. For
// pushing to dockerhub, see the docker post-processors // pushing to dockerhub, see the docker post-processors
Login bool Login bool
LoginEmail string `mapstructure:"login_email"` LoginEmail string `mapstructure:"login_email"`
LoginPassword string `mapstructure:"login_password"` LoginPassword string `mapstructure:"login_password"`
LoginServer string `mapstructure:"login_server"` LoginServer string `mapstructure:"login_server"`
LoginUsername string `mapstructure:"login_username"` LoginUsername string `mapstructure:"login_username"`
EcrLogin bool `mapstructure:"ecr_login"`
AwsAccessConfig `mapstructure:",squash"`
ctx interpolate.Context ctx interpolate.Context
} }
@ -107,6 +109,10 @@ func NewConfig(raws ...interface{}) (*Config, []string, error) {
} }
} }
if c.EcrLogin && c.LoginServer == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("ECR login requires login server to be provided."))
}
if errs != nil && len(errs.Errors) > 0 { if errs != nil && len(errs.Errors) > 0 {
return nil, nil, errs return nil, nil, errs
} }

View File

@ -0,0 +1,88 @@
package docker
import (
"encoding/base64"
"fmt"
"log"
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecr"
)
type AwsAccessConfig struct {
AccessKey string `mapstructure:"aws_access_key"`
SecretKey string `mapstructure:"aws_secret_key"`
Token string `mapstructure:"aws_token"`
}
// Config returns a valid aws.Config object for access to AWS services, or
// an error if the authentication and region couldn't be resolved
func (c *AwsAccessConfig) config(region string) (*aws.Config, error) {
var creds *credentials.Credentials
config := aws.NewConfig().WithRegion(region).WithMaxRetries(11)
sess := session.New(config)
creds = credentials.NewChainCredentials([]credentials.Provider{
&credentials.StaticProvider{Value: credentials.Value{
AccessKeyID: c.AccessKey,
SecretAccessKey: c.SecretKey,
SessionToken: c.Token,
}},
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(sess),
},
})
return config.WithCredentials(creds), nil
}
// Get a login token for Amazon AWS ECR. Returns username and password
// or an error.
func (c *AwsAccessConfig) EcrGetLogin(ecrUrl string) (string, string, error) {
exp := regexp.MustCompile("(?:http://|https://|)([0-9]*)\\.dkr\\.ecr\\.(.*)\\.amazonaws\\.com.*")
splitUrl := exp.FindStringSubmatch(ecrUrl)
accountId := splitUrl[1]
region := splitUrl[2]
log.Println(fmt.Sprintf("Getting ECR token for account: %s in %s..", accountId, region))
awsConfig, err := c.config(region)
if err != nil {
return "", "", err
}
session, err := session.NewSession(awsConfig)
if err != nil {
return "", "", fmt.Errorf("failed to create session: %s", err)
}
service := ecr.New(session)
params := &ecr.GetAuthorizationTokenInput{
RegistryIds: []*string{
aws.String(accountId),
},
}
resp, err := service.GetAuthorizationToken(params)
if err != nil {
return "", "", fmt.Errorf(err.Error())
}
auth, err := base64.StdEncoding.DecodeString(*resp.AuthorizationData[0].AuthorizationToken)
if err != nil {
return "", "", fmt.Errorf("Error decoding ECR AuthorizationToken: %s", err)
}
authParts := strings.SplitN(string(auth), ":", 2)
log.Printf("Successfully got login for ECR: %s", ecrUrl)
return authParts[0], authParts[1], nil
}

View File

@ -21,7 +21,22 @@ func (s *StepPull) Run(state multistep.StateBag) multistep.StepAction {
ui.Say(fmt.Sprintf("Pulling Docker image: %s", config.Image)) ui.Say(fmt.Sprintf("Pulling Docker image: %s", config.Image))
if config.Login { if config.EcrLogin {
ui.Message("Fetching ECR credentials...")
username, password, err := config.EcrGetLogin(config.LoginServer)
if err != nil {
err := fmt.Errorf("Error fetching ECR credentials: %s", err)
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
config.LoginUsername = username
config.LoginPassword = password
}
if config.Login || config.EcrLogin {
ui.Message("Logging in...") ui.Message("Logging in...")
err := driver.Login( err := driver.Login(
config.LoginServer, config.LoginServer,

View File

@ -7,7 +7,7 @@ import (
// Artifact represents a GCE image as the result of a Packer build. // Artifact represents a GCE image as the result of a Packer build.
type Artifact struct { type Artifact struct {
image Image image *Image
driver Driver driver Driver
config *Config config *Config
} }
@ -41,16 +41,16 @@ func (a *Artifact) String() string {
func (a *Artifact) State(name string) interface{} { func (a *Artifact) State(name string) interface{} {
switch name { switch name {
case "ImageName": case "ImageName":
return a.image.Name return a.image.Name
case "ImageSizeGb": case "ImageSizeGb":
return a.image.SizeGb return a.image.SizeGb
case "AccountFilePath": case "AccountFilePath":
return a.config.AccountFile return a.config.AccountFile
case "ProjectId": case "ProjectId":
return a.config.ProjectId return a.config.ProjectId
case "BuildZone": case "BuildZone":
return a.config.Zone return a.config.Zone
} }
return nil return nil
} }

View File

@ -93,7 +93,7 @@ func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe
} }
artifact := &Artifact{ artifact := &Artifact{
image: state.Get("image").(Image), image: state.Get("image").(*Image),
driver: driver, driver: driver,
config: b.config, config: b.config,
} }

View File

@ -25,20 +25,20 @@ func Retry(initialInterval float64, maxInterval float64, numTries uint, function
done := false done := false
interval := initialInterval interval := initialInterval
for i := uint(0); !done && (numTries == 0 || i < numTries); i++ { for i := uint(0); !done && (numTries == 0 || i < numTries); i++ {
done, err = function() done, err = function()
if err != nil { if err != nil {
return err return err
} }
if !done { if !done {
// Retry after delay. Calculate next delay. // Retry after delay. Calculate next delay.
time.Sleep(time.Duration(interval) * time.Second) time.Sleep(time.Duration(interval) * time.Second)
interval = math.Min(interval * 2, maxInterval) interval = math.Min(interval*2, maxInterval)
} }
} }
if !done { if !done {
return RetryExhaustedError return RetryExhaustedError
} }
return nil return nil
} }

View File

@ -18,7 +18,7 @@ func TestRetry(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Passing function should not have returned a retry error. Error: %s", err) t.Fatalf("Passing function should not have returned a retry error. Error: %s", err)
} }
// Test that a failing function gets retried (once in this example). // Test that a failing function gets retried (once in this example).
numTries = 0 numTries = 0
results := []bool{false, true} results := []bool{false, true}
@ -33,7 +33,7 @@ func TestRetry(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Successful retried function should not have returned a retry error. Error: %s", err) t.Fatalf("Successful retried function should not have returned a retry error. Error: %s", err)
} }
// Test that a function error gets returned, and the function does not get called again. // Test that a function error gets returned, and the function does not get called again.
numTries = 0 numTries = 0
funcErr := fmt.Errorf("This function had an error!") funcErr := fmt.Errorf("This function had an error!")
@ -47,7 +47,7 @@ func TestRetry(t *testing.T) {
if err != funcErr { if err != funcErr {
t.Fatalf("Errant function did not return the right error %s. Error: %s", funcErr, err) t.Fatalf("Errant function did not return the right error %s. Error: %s", funcErr, err)
} }
// Test when a function exhausts its retries. // Test when a function exhausts its retries.
numTries = 0 numTries = 0
expectedTries := uint(3) expectedTries := uint(3)
@ -61,4 +61,4 @@ func TestRetry(t *testing.T) {
if err != RetryExhaustedError { if err != RetryExhaustedError {
t.Fatalf("Unsuccessful retry function should have returned a retry exhausted error. Actual error: %s", err) t.Fatalf("Unsuccessful retry function should have returned a retry exhausted error. Actual error: %s", err)
} }
} }

View File

@ -191,4 +191,4 @@ func (c *Config) CalcTimeout() error {
} }
c.stateTimeout = stateTimeout c.stateTimeout = stateTimeout
return nil return nil
} }

View File

@ -4,13 +4,9 @@ package googlecompute
// with GCE. The Driver interface exists mostly to allow a mock implementation // with GCE. The Driver interface exists mostly to allow a mock implementation
// to be used to test the steps. // to be used to test the steps.
type Driver interface { type Driver interface {
// ImageExists returns true if the specified image exists. If an error
// occurs calling the API, this method returns false.
ImageExists(name string) bool
// CreateImage creates an image from the given disk in Google Compute // CreateImage creates an image from the given disk in Google Compute
// Engine. // Engine.
CreateImage(name, description, family, zone, disk string) (<-chan Image, <-chan error) CreateImage(name, description, family, zone, disk string) (<-chan *Image, <-chan error)
// DeleteImage deletes the image with the given name. // DeleteImage deletes the image with the given name.
DeleteImage(name string) <-chan error DeleteImage(name string) <-chan error
@ -21,15 +17,28 @@ type Driver interface {
// DeleteDisk deletes the disk with the given name. // DeleteDisk deletes the disk with the given name.
DeleteDisk(zone, name string) (<-chan error, error) DeleteDisk(zone, name string) (<-chan error, error)
// GetImage gets an image; tries the default and public projects.
GetImage(name string) (*Image, error)
// GetImageFromProject gets an image from a specific project.
GetImageFromProject(project, name string) (*Image, error)
// GetInstanceMetadata gets a metadata variable for the instance, name.
GetInstanceMetadata(zone, name, key string) (string, error)
// GetInternalIP gets the GCE-internal IP address for the instance. // GetInternalIP gets the GCE-internal IP address for the instance.
GetInternalIP(zone, name string) (string, error) GetInternalIP(zone, name string) (string, error)
// GetNatIP gets the NAT IP address for the instance. // GetNatIP gets the NAT IP address for the instance.
GetNatIP(zone, name string) (string, error) GetNatIP(zone, name string) (string, error)
// GetSerialPortOutput gets the Serial Port contents for the instance. // GetSerialPortOutput gets the Serial Port contents for the instance.
GetSerialPortOutput(zone, name string) (string, error) GetSerialPortOutput(zone, name string) (string, error)
// ImageExists returns true if the specified image exists. If an error
// occurs calling the API, this method returns false.
ImageExists(name string) bool
// RunInstance takes the given config and launches an instance. // RunInstance takes the given config and launches an instance.
RunInstance(*InstanceConfig) (<-chan error, error) RunInstance(*InstanceConfig) (<-chan error, error)
@ -37,18 +46,12 @@ type Driver interface {
WaitForInstance(state, zone, name string) <-chan error WaitForInstance(state, zone, name string) <-chan error
} }
type Image struct {
Name string
ProjectId string
SizeGb int64
}
type InstanceConfig struct { type InstanceConfig struct {
Address string Address string
Description string Description string
DiskSizeGb int64 DiskSizeGb int64
DiskType string DiskType string
Image Image Image *Image
MachineType string MachineType string
Metadata map[string]string Metadata map[string]string
Name string Name string

View File

@ -5,6 +5,7 @@ import (
"log" "log"
"net/http" "net/http"
"runtime" "runtime"
"strings"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"github.com/mitchellh/packer/version" "github.com/mitchellh/packer/version"
@ -13,7 +14,6 @@ import (
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
"google.golang.org/api/compute/v1" "google.golang.org/api/compute/v1"
"strings"
) )
// driverGCE is a Driver implementation that actually talks to GCE. // driverGCE is a Driver implementation that actually talks to GCE.
@ -88,15 +88,8 @@ func NewDriverGCE(ui packer.Ui, p string, a *AccountFile) (Driver, error) {
}, nil }, nil
} }
func (d *driverGCE) ImageExists(name string) bool { func (d *driverGCE) CreateImage(name, description, family, zone, disk string) (<-chan *Image, <-chan error) {
_, err := d.service.Images.Get(d.projectId, name).Do() gce_image := &compute.Image{
// The API may return an error for reasons other than the image not
// existing, but this heuristic is sufficient for now.
return err == nil
}
func (d *driverGCE) CreateImage(name, description, family, zone, disk string) (<-chan Image, <-chan error) {
image := &compute.Image{
Description: description, Description: description,
Name: name, Name: name,
Family: family, Family: family,
@ -104,9 +97,9 @@ func (d *driverGCE) CreateImage(name, description, family, zone, disk string) (<
SourceType: "RAW", SourceType: "RAW",
} }
imageCh := make(chan Image, 1) imageCh := make(chan *Image, 1)
errCh := make(chan error, 1) errCh := make(chan error, 1)
op, err := d.service.Images.Insert(d.projectId, image).Do() op, err := d.service.Images.Insert(d.projectId, gce_image).Do()
if err != nil { if err != nil {
errCh <- err errCh <- err
} else { } else {
@ -114,17 +107,17 @@ func (d *driverGCE) CreateImage(name, description, family, zone, disk string) (<
err = waitForState(errCh, "DONE", d.refreshGlobalOp(op)) err = waitForState(errCh, "DONE", d.refreshGlobalOp(op))
if err != nil { if err != nil {
close(imageCh) close(imageCh)
errCh <- err
return
} }
image, err = d.getImage(name, d.projectId) var image *Image
image, err = d.GetImageFromProject(d.projectId, name)
if err != nil { if err != nil {
close(imageCh) close(imageCh)
errCh <- err errCh <- err
return
} }
imageCh <- Image{ imageCh <- image
Name: name,
ProjectId: d.projectId,
SizeGb: image.DiskSizeGb,
}
close(imageCh) close(imageCh)
}() }()
} }
@ -166,6 +159,57 @@ func (d *driverGCE) DeleteDisk(zone, name string) (<-chan error, error) {
return errCh, nil return errCh, nil
} }
func (d *driverGCE) GetImage(name string) (*Image, error) {
projects := []string{d.projectId, "centos-cloud", "coreos-cloud", "debian-cloud", "google-containers", "opensuse-cloud", "rhel-cloud", "suse-cloud", "ubuntu-os-cloud", "windows-cloud"}
var errs error
for _, project := range projects {
image, err := d.GetImageFromProject(project, name)
if err != nil {
errs = packer.MultiErrorAppend(errs, err)
}
if image != nil {
return image, nil
}
}
return nil, fmt.Errorf(
"Could not find image, %s, in projects, %s: %s", name,
projects, errs)
}
func (d *driverGCE) GetImageFromProject(project, name string) (*Image, error) {
image, err := d.service.Images.Get(project, name).Do()
if err != nil {
return nil, err
} else if image == nil || image.SelfLink == "" {
return nil, fmt.Errorf("Image, %s, could not be found in project: %s", name, project)
} else {
return &Image{
Licenses: image.Licenses,
Name: image.Name,
ProjectId: project,
SelfLink: image.SelfLink,
SizeGb: image.DiskSizeGb,
}, nil
}
}
func (d *driverGCE) GetInstanceMetadata(zone, name, key string) (string, error) {
instance, err := d.service.Instances.Get(d.projectId, zone, name).Do()
if err != nil {
return "", err
}
for _, item := range instance.Metadata.Items {
if item.Key == key {
return *item.Value, nil
}
}
return "", fmt.Errorf("Instance metadata key, %s, not found.", key)
}
func (d *driverGCE) GetNatIP(zone, name string) (string, error) { func (d *driverGCE) GetNatIP(zone, name string) (string, error) {
instance, err := d.service.Instances.Get(d.projectId, zone, name).Do() instance, err := d.service.Instances.Get(d.projectId, zone, name).Do()
if err != nil { if err != nil {
@ -207,10 +251,17 @@ func (d *driverGCE) GetSerialPortOutput(zone, name string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.Contents, nil return output.Contents, nil
} }
func (d *driverGCE) ImageExists(name string) bool {
_, err := d.GetImageFromProject(d.projectId, name)
// The API may return an error for reasons other than the image not
// existing, but this heuristic is sufficient for now.
return err == nil
}
func (d *driverGCE) RunInstance(c *InstanceConfig) (<-chan error, error) { func (d *driverGCE) RunInstance(c *InstanceConfig) (<-chan error, error) {
// Get the zone // Get the zone
d.ui.Message(fmt.Sprintf("Loading zone: %s", c.Zone)) d.ui.Message(fmt.Sprintf("Loading zone: %s", c.Zone))
@ -219,13 +270,6 @@ func (d *driverGCE) RunInstance(c *InstanceConfig) (<-chan error, error) {
return nil, err return nil, err
} }
// Get the image
d.ui.Message(fmt.Sprintf("Loading image: %s in project %s", c.Image.Name, c.Image.ProjectId))
image, err := d.getImage(c.Image.Name, c.Image.ProjectId)
if err != nil {
return nil, err
}
// Get the machine type // Get the machine type
d.ui.Message(fmt.Sprintf("Loading machine type: %s", c.MachineType)) d.ui.Message(fmt.Sprintf("Loading machine type: %s", c.MachineType))
machineType, err := d.service.MachineTypes.Get( machineType, err := d.service.MachineTypes.Get(
@ -302,7 +346,7 @@ func (d *driverGCE) RunInstance(c *InstanceConfig) (<-chan error, error) {
Boot: true, Boot: true,
AutoDelete: false, AutoDelete: false,
InitializeParams: &compute.AttachedDiskInitializeParams{ InitializeParams: &compute.AttachedDiskInitializeParams{
SourceImage: image.SelfLink, SourceImage: c.Image.SelfLink,
DiskSizeGb: c.DiskSizeGb, DiskSizeGb: c.DiskSizeGb,
DiskType: fmt.Sprintf("zones/%s/diskTypes/%s", zone.Name, c.DiskType), DiskType: fmt.Sprintf("zones/%s/diskTypes/%s", zone.Name, c.DiskType),
}, },
@ -355,20 +399,6 @@ func (d *driverGCE) WaitForInstance(state, zone, name string) <-chan error {
return errCh return errCh
} }
func (d *driverGCE) getImage(name, projectId string) (image *compute.Image, err error) {
projects := []string{projectId, "centos-cloud", "coreos-cloud", "debian-cloud", "google-containers", "opensuse-cloud", "rhel-cloud", "suse-cloud", "ubuntu-os-cloud", "windows-cloud"}
for _, project := range projects {
image, err = d.service.Images.Get(project, name).Do()
if err == nil && image != nil && image.SelfLink != "" {
return
}
image = nil
}
err = fmt.Errorf("Image %s could not be found in any of these projects: %s", name, projects)
return
}
func (d *driverGCE) refreshInstanceState(zone, name string) stateRefreshFunc { func (d *driverGCE) refreshInstanceState(zone, name string) stateRefreshFunc {
return func() (string, error) { return func() (string, error) {
instance, err := d.service.Instances.Get(d.projectId, zone, name).Do() instance, err := d.service.Instances.Get(d.projectId, zone, name).Do()

View File

@ -1,20 +1,21 @@
package googlecompute package googlecompute
import "fmt"
// DriverMock is a Driver implementation that is a mocked out so that // DriverMock is a Driver implementation that is a mocked out so that
// it can be used for tests. // it can be used for tests.
type DriverMock struct { type DriverMock struct {
ImageExistsName string CreateImageName string
ImageExistsResult bool CreateImageDesc string
CreateImageFamily string
CreateImageName string CreateImageZone string
CreateImageDesc string CreateImageDisk string
CreateImageFamily string CreateImageResultLicenses []string
CreateImageZone string CreateImageResultProjectId string
CreateImageDisk string CreateImageResultSelfLink string
CreateImageProjectId string CreateImageResultSizeGb int64
CreateImageSizeGb int64 CreateImageErrCh <-chan error
CreateImageErrCh <-chan error CreateImageResultCh <-chan *Image
CreateImageResultCh <-chan Image
DeleteImageName string DeleteImageName string
DeleteImageErrCh <-chan error DeleteImageErrCh <-chan error
@ -29,6 +30,21 @@ type DriverMock struct {
DeleteDiskErrCh <-chan error DeleteDiskErrCh <-chan error
DeleteDiskErr error DeleteDiskErr error
GetImageName string
GetImageResult *Image
GetImageErr error
GetImageFromProjectProject string
GetImageFromProjectName string
GetImageFromProjectResult *Image
GetImageFromProjectErr error
GetInstanceMetadataZone string
GetInstanceMetadataName string
GetInstanceMetadataKey string
GetInstanceMetadataResult string
GetInstanceMetadataErr error
GetNatIPZone string GetNatIPZone string
GetNatIPName string GetNatIPName string
GetNatIPResult string GetNatIPResult string
@ -38,12 +54,15 @@ type DriverMock struct {
GetInternalIPName string GetInternalIPName string
GetInternalIPResult string GetInternalIPResult string
GetInternalIPErr error GetInternalIPErr error
GetSerialPortOutputZone string GetSerialPortOutputZone string
GetSerialPortOutputName string GetSerialPortOutputName string
GetSerialPortOutputResult string GetSerialPortOutputResult string
GetSerialPortOutputErr error GetSerialPortOutputErr error
ImageExistsName string
ImageExistsResult bool
RunInstanceConfig *InstanceConfig RunInstanceConfig *InstanceConfig
RunInstanceErrCh <-chan error RunInstanceErrCh <-chan error
RunInstanceErr error RunInstanceErr error
@ -54,31 +73,33 @@ type DriverMock struct {
WaitForInstanceErrCh <-chan error WaitForInstanceErrCh <-chan error
} }
func (d *DriverMock) ImageExists(name string) bool { func (d *DriverMock) CreateImage(name, description, family, zone, disk string) (<-chan *Image, <-chan error) {
d.ImageExistsName = name
return d.ImageExistsResult
}
func (d *DriverMock) CreateImage(name, description, family, zone, disk string) (<-chan Image, <-chan error) {
d.CreateImageName = name d.CreateImageName = name
d.CreateImageDesc = description d.CreateImageDesc = description
d.CreateImageFamily = family d.CreateImageFamily = family
d.CreateImageZone = zone d.CreateImageZone = zone
d.CreateImageDisk = disk d.CreateImageDisk = disk
if d.CreateImageSizeGb == 0 { if d.CreateImageResultProjectId == "" {
d.CreateImageSizeGb = 10 d.CreateImageResultProjectId = "test"
} }
if d.CreateImageProjectId == "" { if d.CreateImageResultSelfLink == "" {
d.CreateImageProjectId = "test" d.CreateImageResultSelfLink = fmt.Sprintf(
"http://content.googleapis.com/compute/v1/%s/global/licenses/test",
d.CreateImageResultProjectId)
}
if d.CreateImageResultSizeGb == 0 {
d.CreateImageResultSizeGb = 10
} }
resultCh := d.CreateImageResultCh resultCh := d.CreateImageResultCh
if resultCh == nil { if resultCh == nil {
ch := make(chan Image, 1) ch := make(chan *Image, 1)
ch <- Image{ ch <- &Image{
Licenses: d.CreateImageResultLicenses,
Name: name, Name: name,
ProjectId: d.CreateImageProjectId, ProjectId: d.CreateImageResultProjectId,
SizeGb: d.CreateImageSizeGb, SelfLink: d.CreateImageResultSelfLink,
SizeGb: d.CreateImageResultSizeGb,
} }
close(ch) close(ch)
resultCh = ch resultCh = ch
@ -135,6 +156,24 @@ func (d *DriverMock) DeleteDisk(zone, name string) (<-chan error, error) {
return resultCh, d.DeleteDiskErr return resultCh, d.DeleteDiskErr
} }
func (d *DriverMock) GetImage(name string) (*Image, error) {
d.GetImageName = name
return d.GetImageResult, d.GetImageErr
}
func (d *DriverMock) GetImageFromProject(project, name string) (*Image, error) {
d.GetImageFromProjectProject = project
d.GetImageFromProjectName = name
return d.GetImageFromProjectResult, d.GetImageFromProjectErr
}
func (d *DriverMock) GetInstanceMetadata(zone, name, key string) (string, error) {
d.GetInstanceMetadataZone = zone
d.GetInstanceMetadataName = name
d.GetInstanceMetadataKey = key
return d.GetInstanceMetadataResult, d.GetInstanceMetadataErr
}
func (d *DriverMock) GetNatIP(zone, name string) (string, error) { func (d *DriverMock) GetNatIP(zone, name string) (string, error) {
d.GetNatIPZone = zone d.GetNatIPZone = zone
d.GetNatIPName = name d.GetNatIPName = name
@ -153,6 +192,11 @@ func (d *DriverMock) GetSerialPortOutput(zone, name string) (string, error) {
return d.GetSerialPortOutputResult, d.GetSerialPortOutputErr return d.GetSerialPortOutputResult, d.GetSerialPortOutputErr
} }
func (d *DriverMock) ImageExists(name string) bool {
d.ImageExistsName = name
return d.ImageExistsResult
}
func (d *DriverMock) RunInstance(c *InstanceConfig) (<-chan error, error) { func (d *DriverMock) RunInstance(c *InstanceConfig) (<-chan error, error) {
d.RunInstanceConfig = c d.RunInstanceConfig = c

View File

@ -0,0 +1,22 @@
package googlecompute
import (
"strings"
)
type Image struct {
Licenses []string
Name string
ProjectId string
SelfLink string
SizeGb int64
}
func (i *Image) IsWindows() bool {
for _, license := range i.Licenses {
if strings.Contains(license, "windows") {
return true
}
}
return false
}

View File

@ -0,0 +1,26 @@
package googlecompute
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func StubImage(name, project string, licenses []string, sizeGb int64) *Image {
return &Image{
Licenses: licenses,
Name: name,
ProjectId: project,
SelfLink: fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/global/images/%s", project, name),
SizeGb: sizeGb,
}
}
func TestImage_IsWindows(t *testing.T) {
i := StubImage("foo", "foo-project", []string{"license-foo", "license-bar"}, 100)
assert.False(t, i.IsWindows())
i = StubImage("foo", "foo-project", []string{"license-foo", "windows-license"}, 100)
assert.True(t, i.IsWindows())
}

View File

@ -1,37 +1,40 @@
package googlecompute package googlecompute
import ( import (
"encoding/base64"
"fmt" "fmt"
) )
const StartupScriptStartLog string = "Packer startup script starting."
const StartupScriptDoneLog string = "Packer startup script done."
const StartupScriptKey string = "startup-script" const StartupScriptKey string = "startup-script"
const StartupScriptStatusKey string = "startup-script-status"
const StartupWrappedScriptKey string = "packer-wrapped-startup-script" const StartupWrappedScriptKey string = "packer-wrapped-startup-script"
// We have to encode StartupScriptDoneLog because we use it as a sentinel value to indicate const StartupScriptStatusDone string = "done"
// that the user-provided startup script is done. If we pass StartupScriptDoneLog as-is, it const StartupScriptStatusError string = "error"
// will be printed early in the instance console log (before the startup script even runs; const StartupScriptStatusNotDone string = "notdone"
// we print out instance creation metadata which contains this wrapper script).
var StartupScriptDoneLogBase64 string = base64.StdEncoding.EncodeToString([]byte(StartupScriptDoneLog))
var StartupScript string = fmt.Sprintf(`#!/bin/bash var StartupScriptLinux string = fmt.Sprintf(`#!/bin/bash
echo %s echo "Packer startup script starting."
RETVAL=0 RETVAL=0
BASEMETADATAURL=http://metadata/computeMetadata/v1/instance/
GetMetadata () { GetMetadata () {
echo "$(curl -f -H "Metadata-Flavor: Google" http://metadata/computeMetadata/v1/instance/attributes/$1 2> /dev/null)" echo "$(curl -f -H "Metadata-Flavor: Google" ${BASEMETADATAURL}/${1} 2> /dev/null)"
} }
STARTUPSCRIPT=$(GetMetadata %s) ZONE=$(GetMetadata zone | grep -oP "[^/]*$")
SetMetadata () {
gcloud compute instances add-metadata ${HOSTNAME} --metadata ${1}=${2} --zone ${ZONE}
}
STARTUPSCRIPT=$(GetMetadata attributes/%s)
STARTUPSCRIPTPATH=/packer-wrapped-startup-script STARTUPSCRIPTPATH=/packer-wrapped-startup-script
if [ -f "/var/log/startupscript.log" ]; then if [ -f "/var/log/startupscript.log" ]; then
STARTUPSCRIPTLOGPATH=/var/log/startupscript.log STARTUPSCRIPTLOGPATH=/var/log/startupscript.log
else else
STARTUPSCRIPTLOGPATH=/var/log/daemon.log STARTUPSCRIPTLOGPATH=/var/log/daemon.log
fi fi
STARTUPSCRIPTLOGDEST=$(GetMetadata startup-script-log-dest) STARTUPSCRIPTLOGDEST=$(GetMetadata attributes/startup-script-log-dest)
if [[ ! -z $STARTUPSCRIPT ]]; then if [[ ! -z $STARTUPSCRIPT ]]; then
echo "Executing user-provided startup script..." echo "Executing user-provided startup script..."
@ -48,6 +51,9 @@ if [[ ! -z $STARTUPSCRIPT ]]; then
rm ${STARTUPSCRIPTPATH} rm ${STARTUPSCRIPTPATH}
fi fi
echo $(echo %s | base64 --decode) echo "Packer startup script done."
SetMetadata %s %s
exit $RETVAL exit $RETVAL
`, StartupScriptStartLog, StartupWrappedScriptKey, StartupScriptDoneLogBase64) `, StartupWrappedScriptKey, StartupScriptStatusKey, StartupScriptStatusDone)
var StartupScriptWindows string = ""

View File

@ -13,14 +13,14 @@ type StepCheckExistingImage int
// Run executes the Packer build step that checks if the image already exists. // Run executes the Packer build step that checks if the image already exists.
func (s *StepCheckExistingImage) Run(state multistep.StateBag) multistep.StepAction { func (s *StepCheckExistingImage) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config) c := state.Get("config").(*Config)
driver := state.Get("driver").(Driver) d := state.Get("driver").(Driver)
ui := state.Get("ui").(packer.Ui) ui := state.Get("ui").(packer.Ui)
ui.Say("Checking image does not exist...") ui.Say("Checking image does not exist...")
exists := driver.ImageExists(config.ImageName) exists := d.ImageExists(c.ImageName)
if exists { if exists {
err := fmt.Errorf("Image %s already exists", config.ImageName) err := fmt.Errorf("Image %s already exists", c.ImageName)
state.Put("error", err) state.Put("error", err)
ui.Error(err.Error()) ui.Error(err.Error())
return multistep.ActionHalt return multistep.ActionHalt

View File

@ -24,7 +24,9 @@ func (s *StepCreateImage) Run(state multistep.StateBag) multistep.StepAction {
ui.Say("Creating image...") ui.Say("Creating image...")
imageCh, errCh := driver.CreateImage(config.ImageName, config.ImageDescription, config.ImageFamily, config.Zone, config.DiskName) imageCh, errCh := driver.CreateImage(
config.ImageName, config.ImageDescription, config.ImageFamily, config.Zone,
config.DiskName)
var err error var err error
select { select {
case err = <-errCh: case err = <-errCh:

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/mitchellh/multistep" "github.com/mitchellh/multistep"
"github.com/stretchr/testify/assert"
) )
func TestStepCreateImage_impl(t *testing.T) { func TestStepCreateImage_impl(t *testing.T) {
@ -16,52 +17,35 @@ func TestStepCreateImage(t *testing.T) {
step := new(StepCreateImage) step := new(StepCreateImage)
defer step.Cleanup(state) defer step.Cleanup(state)
config := state.Get("config").(*Config) c := state.Get("config").(*Config)
driver := state.Get("driver").(*DriverMock) d := state.Get("driver").(*DriverMock)
driver.CreateImageProjectId = "createimage-project"
driver.CreateImageSizeGb = 100 // These are the values of the image the driver will return.
d.CreateImageResultLicenses = []string{"test-license"}
d.CreateImageResultProjectId = "test-project"
d.CreateImageResultSizeGb = 100
// run the step // run the step
if action := step.Run(state); action != multistep.ActionContinue { action := step.Run(state)
t.Fatalf("bad action: %#v", action) assert.Equal(t, action, multistep.ActionContinue, "Step did not pass.")
}
uncastImage, ok := state.GetOk("image") uncastImage, ok := state.GetOk("image")
if !ok { assert.True(t, ok, "State does not have resulting image.")
t.Fatal("should have image") image, ok := uncastImage.(*Image)
} assert.True(t, ok, "Image in state is not an Image.")
image, ok := uncastImage.(Image)
if !ok {
t.Fatal("image is not an Image")
}
// Verify created Image results. // Verify created Image results.
if image.Name != config.ImageName { assert.Equal(t, image.Licenses, d.CreateImageResultLicenses, "Created image licenses don't match the licenses returned by the driver.")
t.Fatalf("Created image name, %s, does not match config name, %s.", image.Name, config.ImageName) assert.Equal(t, image.Name, c.ImageName, "Created image does not match config name.")
} assert.Equal(t, image.ProjectId, d.CreateImageResultProjectId, "Created image project does not match driver project.")
if driver.CreateImageProjectId != image.ProjectId { assert.Equal(t, image.SizeGb, d.CreateImageResultSizeGb, "Created image size does not match the size returned by the driver.")
t.Fatalf("Created image project ID, %s, does not match driver project ID, %s.", image.ProjectId, driver.CreateImageProjectId)
}
if driver.CreateImageSizeGb != image.SizeGb {
t.Fatalf("Created image size, %d, does not match the expected test value, %d.", image.SizeGb, driver.CreateImageSizeGb)
}
// Verify proper args passed to driver.CreateImage. // Verify proper args passed to driver.CreateImage.
if driver.CreateImageName != config.ImageName { assert.Equal(t, d.CreateImageName, c.ImageName, "Incorrect image name passed to driver.")
t.Fatalf("bad: %#v", driver.CreateImageName) assert.Equal(t, d.CreateImageDesc, c.ImageDescription, "Incorrect image description passed to driver.")
} assert.Equal(t, d.CreateImageFamily, c.ImageFamily, "Incorrect image family passed to driver.")
if driver.CreateImageDesc != config.ImageDescription { assert.Equal(t, d.CreateImageZone, c.Zone, "Incorrect image zone passed to driver.")
t.Fatalf("bad: %#v", driver.CreateImageDesc) assert.Equal(t, d.CreateImageDisk, c.DiskName, "Incorrect disk passed to driver.")
}
if driver.CreateImageFamily != config.ImageFamily {
t.Fatalf("bad: %#v", driver.CreateImageFamily)
}
if driver.CreateImageZone != config.Zone {
t.Fatalf("bad: %#v", driver.CreateImageZone)
}
if driver.CreateImageDisk != config.DiskName {
t.Fatalf("bad: %#v", driver.CreateImageDisk)
}
} }
func TestStepCreateImage_errorOnChannel(t *testing.T) { func TestStepCreateImage_errorOnChannel(t *testing.T) {
@ -76,14 +60,10 @@ func TestStepCreateImage_errorOnChannel(t *testing.T) {
driver.CreateImageErrCh = errCh driver.CreateImageErrCh = errCh
// run the step // run the step
if action := step.Run(state); action != multistep.ActionHalt { action := step.Run(state)
t.Fatalf("bad action: %#v", action) assert.Equal(t, action, multistep.ActionHalt, "Step should not have passed.")
} _, ok := state.GetOk("error")
assert.True(t, ok, "State should have an error.")
if _, ok := state.GetOk("error"); !ok { _, ok = state.GetOk("image_name")
t.Fatal("should have error") assert.False(t, ok, "State should not have a resulting image.")
}
if _, ok := state.GetOk("image_name"); ok {
t.Fatal("should NOT have image")
}
} }

View File

@ -15,82 +15,97 @@ type StepCreateInstance struct {
Debug bool Debug bool
} }
func (config *Config) getImage() Image { func (c *Config) createInstanceMetadata(sourceImage *Image, sshPublicKey string) (map[string]string, error) {
project := config.ProjectId
if config.SourceImageProjectId != "" {
project = config.SourceImageProjectId
}
return Image{Name: config.SourceImage, ProjectId: project}
}
func (config *Config) getInstanceMetadata(sshPublicKey string) (map[string]string, error) {
instanceMetadata := make(map[string]string) instanceMetadata := make(map[string]string)
var err error var err error
// Copy metadata from config. // Copy metadata from config.
for k, v := range config.Metadata { for k, v := range c.Metadata {
instanceMetadata[k] = v instanceMetadata[k] = v
} }
// Merge any existing ssh keys with our public key. // Merge any existing ssh keys with our public key.
sshMetaKey := "sshKeys" sshMetaKey := "sshKeys"
sshKeys := fmt.Sprintf("%s:%s", config.Comm.SSHUsername, sshPublicKey) sshKeys := fmt.Sprintf("%s:%s", c.Comm.SSHUsername, sshPublicKey)
if confSshKeys, exists := instanceMetadata[sshMetaKey]; exists { if confSshKeys, exists := instanceMetadata[sshMetaKey]; exists {
sshKeys = fmt.Sprintf("%s\n%s", sshKeys, confSshKeys) sshKeys = fmt.Sprintf("%s\n%s", sshKeys, confSshKeys)
} }
instanceMetadata[sshMetaKey] = sshKeys instanceMetadata[sshMetaKey] = sshKeys
// Wrap any startup script with our own startup script. // Wrap any startup script with our own startup script.
if config.StartupScriptFile != "" { if c.StartupScriptFile != "" {
var content []byte var content []byte
content, err = ioutil.ReadFile(config.StartupScriptFile) content, err = ioutil.ReadFile(c.StartupScriptFile)
instanceMetadata[StartupWrappedScriptKey] = string(content) instanceMetadata[StartupWrappedScriptKey] = string(content)
} else if wrappedStartupScript, exists := instanceMetadata[StartupScriptKey]; exists { } else if wrappedStartupScript, exists := instanceMetadata[StartupScriptKey]; exists {
instanceMetadata[StartupWrappedScriptKey] = wrappedStartupScript instanceMetadata[StartupWrappedScriptKey] = wrappedStartupScript
} }
instanceMetadata[StartupScriptKey] = StartupScript if sourceImage.IsWindows() {
// Windows startup script support is not yet implemented.
// Mark the startup script as done.
instanceMetadata[StartupScriptKey] = StartupScriptWindows
instanceMetadata[StartupScriptStatusKey] = StartupScriptStatusDone
} else {
instanceMetadata[StartupScriptKey] = StartupScriptLinux
instanceMetadata[StartupScriptStatusKey] = StartupScriptStatusNotDone
}
return instanceMetadata, err return instanceMetadata, err
} }
func getImage(c *Config, d Driver) (*Image, error) {
if c.SourceImageProjectId == "" {
return d.GetImage(c.SourceImage)
} else {
return d.GetImageFromProject(c.SourceImageProjectId, c.SourceImage)
}
}
// Run executes the Packer build step that creates a GCE instance. // Run executes the Packer build step that creates a GCE instance.
func (s *StepCreateInstance) Run(state multistep.StateBag) multistep.StepAction { func (s *StepCreateInstance) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config) c := state.Get("config").(*Config)
driver := state.Get("driver").(Driver) d := state.Get("driver").(Driver)
sshPublicKey := state.Get("ssh_public_key").(string) sshPublicKey := state.Get("ssh_public_key").(string)
ui := state.Get("ui").(packer.Ui) ui := state.Get("ui").(packer.Ui)
sourceImage, err := getImage(c, d)
if err != nil {
err := fmt.Errorf("Error getting source image for instance creation: %s", err)
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
ui.Say("Creating instance...") ui.Say("Creating instance...")
name := config.InstanceName name := c.InstanceName
var errCh <-chan error var errCh <-chan error
var err error
var metadata map[string]string var metadata map[string]string
metadata, err = config.getInstanceMetadata(sshPublicKey) metadata, err = c.createInstanceMetadata(sourceImage, sshPublicKey)
errCh, err = driver.RunInstance(&InstanceConfig{ errCh, err = d.RunInstance(&InstanceConfig{
Address: config.Address, Address: c.Address,
Description: "New instance created by Packer", Description: "New instance created by Packer",
DiskSizeGb: config.DiskSizeGb, DiskSizeGb: c.DiskSizeGb,
DiskType: config.DiskType, DiskType: c.DiskType,
Image: config.getImage(), Image: sourceImage,
MachineType: config.MachineType, MachineType: c.MachineType,
Metadata: metadata, Metadata: metadata,
Name: name, Name: name,
Network: config.Network, Network: c.Network,
OmitExternalIP: config.OmitExternalIP, OmitExternalIP: c.OmitExternalIP,
Preemptible: config.Preemptible, Preemptible: c.Preemptible,
Region: config.Region, Region: c.Region,
ServiceAccountEmail: config.Account.ClientEmail, ServiceAccountEmail: c.Account.ClientEmail,
Subnetwork: config.Subnetwork, Subnetwork: c.Subnetwork,
Tags: config.Tags, Tags: c.Tags,
Zone: config.Zone, Zone: c.Zone,
}) })
if err == nil { if err == nil {
ui.Message("Waiting for creation operation to complete...") ui.Message("Waiting for creation operation to complete...")
select { select {
case err = <-errCh: case err = <-errCh:
case <-time.After(config.stateTimeout): case <-time.After(c.stateTimeout):
err = errors.New("time out while waiting for instance to create") err = errors.New("time out while waiting for instance to create")
} }
} }
@ -106,7 +121,7 @@ func (s *StepCreateInstance) Run(state multistep.StateBag) multistep.StepAction
if s.Debug { if s.Debug {
if name != "" { if name != "" {
ui.Message(fmt.Sprintf("Instance: %s started in %s", name, config.Zone)) ui.Message(fmt.Sprintf("Instance: %s started in %s", name, c.Zone))
} }
} }

View File

@ -2,9 +2,11 @@ package googlecompute
import ( import (
"errors" "errors"
"github.com/mitchellh/multistep"
"testing" "testing"
"time" "time"
"github.com/mitchellh/multistep"
"github.com/stretchr/testify/assert"
) )
func TestStepCreateInstance_impl(t *testing.T) { func TestStepCreateInstance_impl(t *testing.T) {
@ -18,36 +20,25 @@ func TestStepCreateInstance(t *testing.T) {
state.Put("ssh_public_key", "key") state.Put("ssh_public_key", "key")
config := state.Get("config").(*Config) c := state.Get("config").(*Config)
driver := state.Get("driver").(*DriverMock) d := state.Get("driver").(*DriverMock)
d.GetImageResult = StubImage("test-image", "test-project", []string{}, 100)
// run the step // run the step
if action := step.Run(state); action != multistep.ActionContinue { assert.Equal(t, step.Run(state), multistep.ActionContinue, "Step should have passed and continued.")
t.Fatalf("bad action: %#v", action)
}
// Verify state // Verify state
nameRaw, ok := state.GetOk("instance_name") nameRaw, ok := state.GetOk("instance_name")
if !ok { assert.True(t, ok, "State should have an instance name.")
t.Fatal("should have instance name")
}
// cleanup // cleanup
step.Cleanup(state) step.Cleanup(state)
if driver.DeleteInstanceName != nameRaw.(string) { // Check args passed to the driver.
t.Fatal("should've deleted instance") assert.Equal(t, d.DeleteInstanceName, nameRaw.(string), "Incorrect instance name passed to driver.")
} assert.Equal(t, d.DeleteInstanceZone, c.Zone, "Incorrect instance zone passed to driver.")
if driver.DeleteInstanceZone != config.Zone { assert.Equal(t, d.DeleteDiskName, c.InstanceName, "Incorrect disk name passed to driver.")
t.Fatalf("bad instance zone: %#v", driver.DeleteInstanceZone) assert.Equal(t, d.DeleteDiskZone, c.Zone, "Incorrect disk zone passed to driver.")
}
if driver.DeleteDiskName != config.InstanceName {
t.Fatal("should've deleted disk")
}
if driver.DeleteDiskZone != config.Zone {
t.Fatalf("bad disk zone: %#v", driver.DeleteDiskZone)
}
} }
func TestStepCreateInstance_error(t *testing.T) { func TestStepCreateInstance_error(t *testing.T) {
@ -57,21 +48,18 @@ func TestStepCreateInstance_error(t *testing.T) {
state.Put("ssh_public_key", "key") state.Put("ssh_public_key", "key")
driver := state.Get("driver").(*DriverMock) d := state.Get("driver").(*DriverMock)
driver.RunInstanceErr = errors.New("error") d.RunInstanceErr = errors.New("error")
d.GetImageResult = StubImage("test-image", "test-project", []string{}, 100)
// run the step // run the step
if action := step.Run(state); action != multistep.ActionHalt { assert.Equal(t, step.Run(state), multistep.ActionHalt, "Step should have failed and halted.")
t.Fatalf("bad action: %#v", action)
}
// Verify state // Verify state
if _, ok := state.GetOk("error"); !ok { _, ok := state.GetOk("error")
t.Fatal("should have error") assert.True(t, ok, "State should have an error.")
} _, ok = state.GetOk("instance_name")
if _, ok := state.GetOk("instance_name"); ok { assert.False(t, ok, "State should not have an instance name.")
t.Fatal("should NOT have instance name")
}
} }
func TestStepCreateInstance_errorOnChannel(t *testing.T) { func TestStepCreateInstance_errorOnChannel(t *testing.T) {
@ -79,26 +67,23 @@ func TestStepCreateInstance_errorOnChannel(t *testing.T) {
step := new(StepCreateInstance) step := new(StepCreateInstance)
defer step.Cleanup(state) defer step.Cleanup(state)
state.Put("ssh_public_key", "key")
errCh := make(chan error, 1) errCh := make(chan error, 1)
errCh <- errors.New("error") errCh <- errors.New("error")
state.Put("ssh_public_key", "key") d := state.Get("driver").(*DriverMock)
d.RunInstanceErrCh = errCh
driver := state.Get("driver").(*DriverMock) d.GetImageResult = StubImage("test-image", "test-project", []string{}, 100)
driver.RunInstanceErrCh = errCh
// run the step // run the step
if action := step.Run(state); action != multistep.ActionHalt { assert.Equal(t, step.Run(state), multistep.ActionHalt, "Step should have failed and halted.")
t.Fatalf("bad action: %#v", action)
}
// Verify state // Verify state
if _, ok := state.GetOk("error"); !ok { _, ok := state.GetOk("error")
t.Fatal("should have error") assert.True(t, ok, "State should have an error.")
} _, ok = state.GetOk("instance_name")
if _, ok := state.GetOk("instance_name"); ok { assert.False(t, ok, "State should not have an instance name.")
t.Fatal("should NOT have instance name")
}
} }
func TestStepCreateInstance_errorTimeout(t *testing.T) { func TestStepCreateInstance_errorTimeout(t *testing.T) {
@ -106,30 +91,27 @@ func TestStepCreateInstance_errorTimeout(t *testing.T) {
step := new(StepCreateInstance) step := new(StepCreateInstance)
defer step.Cleanup(state) defer step.Cleanup(state)
state.Put("ssh_public_key", "key")
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
<-time.After(10 * time.Millisecond) <-time.After(10 * time.Millisecond)
errCh <- nil errCh <- nil
}() }()
state.Put("ssh_public_key", "key")
config := state.Get("config").(*Config) config := state.Get("config").(*Config)
config.stateTimeout = 1 * time.Microsecond config.stateTimeout = 1 * time.Microsecond
driver := state.Get("driver").(*DriverMock) d := state.Get("driver").(*DriverMock)
driver.RunInstanceErrCh = errCh d.RunInstanceErrCh = errCh
d.GetImageResult = StubImage("test-image", "test-project", []string{}, 100)
// run the step // run the step
if action := step.Run(state); action != multistep.ActionHalt { assert.Equal(t, step.Run(state), multistep.ActionHalt, "Step should have failed and halted.")
t.Fatalf("bad action: %#v", action)
}
// Verify state // Verify state
if _, ok := state.GetOk("error"); !ok { _, ok := state.GetOk("error")
t.Fatal("should have error") assert.True(t, ok, "State should have an error.")
} _, ok = state.GetOk("instance_name")
if _, ok := state.GetOk("instance_name"); ok { assert.False(t, ok, "State should not have an instance name.")
t.Fatal("should NOT have instance name")
}
} }

View File

@ -71,8 +71,8 @@ func (s *StepTeardownInstance) Cleanup(state multistep.StateBag) {
if err != nil { if err != nil {
ui.Error(fmt.Sprintf( ui.Error(fmt.Sprintf(
"Error deleting disk. Please delete it manually.\n\n"+ "Error deleting disk. Please delete it manually.\n\n"+
"DiskName: %s\n" + "DiskName: %s\n"+
"Zone: %s\n" + "Zone: %s\n"+
"Error: %s", config.DiskName, config.Zone, err)) "Error: %s", config.DiskName, config.Zone, err))
} }

View File

@ -2,9 +2,10 @@ package googlecompute
import ( import (
"bytes" "bytes"
"testing"
"github.com/mitchellh/multistep" "github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"testing"
) )
func testState(t *testing.T) multistep.StateBag { func testState(t *testing.T) multistep.StateBag {

View File

@ -1,8 +1,8 @@
package googlecompute package googlecompute
import( import (
"errors"
"fmt" "fmt"
"strings"
"github.com/mitchellh/multistep" "github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
@ -10,25 +10,32 @@ import(
type StepWaitInstanceStartup int type StepWaitInstanceStartup int
// Run reads the instance serial port output and looks for the log entry indicating the startup script finished. // Run reads the instance metadata and looks for the log entry
// indicating the startup script finished.
func (s *StepWaitInstanceStartup) Run(state multistep.StateBag) multistep.StepAction { func (s *StepWaitInstanceStartup) Run(state multistep.StateBag) multistep.StepAction {
config := state.Get("config").(*Config) config := state.Get("config").(*Config)
driver := state.Get("driver").(Driver) driver := state.Get("driver").(Driver)
ui := state.Get("ui").(packer.Ui) ui := state.Get("ui").(packer.Ui)
instanceName := state.Get("instance_name").(string) instanceName := state.Get("instance_name").(string)
ui.Say("Waiting for any running startup script to finish...") ui.Say("Waiting for any running startup script to finish...")
// Keep checking the serial port output to see if the startup script is done. // Keep checking the serial port output to see if the startup script is done.
err := Retry(10, 60, 0, func() (bool, error) { err := Retry(10, 60, 0, func() (bool, error) {
output, err := driver.GetSerialPortOutput(config.Zone, instanceName) status, err := driver.GetInstanceMetadata(config.Zone,
instanceName, StartupScriptStatusKey)
if err != nil { if err != nil {
err := fmt.Errorf("Error getting serial port output: %s", err) err := fmt.Errorf("Error getting startup script status: %s", err)
return false, err return false, err
} }
done := strings.Contains(output, StartupScriptDoneLog) if status == StartupScriptStatusError {
err = errors.New("Startup script error.")
return false, err
}
done := status == StartupScriptStatusDone
if !done { if !done {
ui.Say("Startup script not finished yet. Waiting...") ui.Say("Startup script not finished yet. Waiting...")
} }

View File

@ -2,37 +2,29 @@ package googlecompute
import ( import (
"github.com/mitchellh/multistep" "github.com/mitchellh/multistep"
"github.com/stretchr/testify/assert"
"testing" "testing"
) )
func TestStepWaitInstanceStartup(t *testing.T) { func TestStepWaitInstanceStartup(t *testing.T) {
state := testState(t) state := testState(t)
step := new(StepWaitInstanceStartup) step := new(StepWaitInstanceStartup)
config := state.Get("config").(*Config) c := state.Get("config").(*Config)
driver := state.Get("driver").(*DriverMock) d := state.Get("driver").(*DriverMock)
testZone := "test-zone" testZone := "test-zone"
testInstanceName := "test-instance-name" testInstanceName := "test-instance-name"
config.Zone = testZone c.Zone = testZone
state.Put("instance_name", testInstanceName) state.Put("instance_name", testInstanceName)
// The done log triggers step completion.
driver.GetSerialPortOutputResult = StartupScriptDoneLog // This step stops when it gets Done back from the metadata.
d.GetInstanceMetadataResult = StartupScriptStatusDone
// Run the step. // Run the step.
if action := step.Run(state); action != multistep.ActionContinue { assert.Equal(t, step.Run(state), multistep.ActionContinue, "Step should have passed and continued.")
t.Fatalf("StepWaitInstanceStartup did not return a Continue action: %#v", action)
} // Check that GetInstanceMetadata was called properly.
assert.Equal(t, d.GetInstanceMetadataZone, testZone, "Incorrect zone passed to GetInstanceMetadata.")
// Check that GetSerialPortOutput was called properly. assert.Equal(t, d.GetInstanceMetadataName, testInstanceName, "Incorrect instance name passed to GetInstanceMetadata.")
if driver.GetSerialPortOutputZone != testZone { }
t.Fatalf(
"GetSerialPortOutput wrong zone. Expected: %s, Actual: %s", driver.GetSerialPortOutputZone,
testZone)
}
if driver.GetSerialPortOutputName != testInstanceName {
t.Fatalf(
"GetSerialPortOutput wrong instance name. Expected: %s, Actual: %s", driver.GetSerialPortOutputName,
testInstanceName)
}
}

View File

@ -1,18 +0,0 @@
package common
import (
"testing"
)
func TestFloppyConfigPrepare(t *testing.T) {
c := new(FloppyConfig)
errs := c.Prepare(testConfigTemplate(t))
if len(errs) > 0 {
t.Fatalf("err: %#v", errs)
}
if len(c.FloppyFiles) > 0 {
t.Fatal("should not have floppy files")
}
}

View File

@ -25,7 +25,7 @@ type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
common.HTTPConfig `mapstructure:",squash"` common.HTTPConfig `mapstructure:",squash"`
common.ISOConfig `mapstructure:",squash"` common.ISOConfig `mapstructure:",squash"`
parallelscommon.FloppyConfig `mapstructure:",squash"` common.FloppyConfig `mapstructure:",squash"`
parallelscommon.OutputConfig `mapstructure:",squash"` parallelscommon.OutputConfig `mapstructure:",squash"`
parallelscommon.PrlctlConfig `mapstructure:",squash"` parallelscommon.PrlctlConfig `mapstructure:",squash"`
parallelscommon.PrlctlPostConfig `mapstructure:",squash"` parallelscommon.PrlctlPostConfig `mapstructure:",squash"`

View File

@ -1,8 +1,11 @@
package iso package iso
import ( import (
"github.com/mitchellh/packer/packer" "fmt"
"reflect"
"testing" "testing"
"github.com/mitchellh/packer/packer"
) )
func testConfig() map[string]interface{} { func testConfig() map[string]interface{} {
@ -46,6 +49,55 @@ func TestBuilderPrepare_Defaults(t *testing.T) {
} }
} }
func TestBuilderPrepare_FloppyFiles(t *testing.T) {
var b Builder
config := testConfig()
delete(config, "floppy_files")
warns, err := b.Prepare(config)
if len(warns) > 0 {
t.Fatalf("bad: %#v", warns)
}
if err != nil {
t.Fatalf("bad err: %s", err)
}
if len(b.config.FloppyFiles) != 0 {
t.Fatalf("bad: %#v", b.config.FloppyFiles)
}
floppies_path := "../../../common/test-fixtures/floppies"
config["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
b = Builder{}
warns, err = b.Prepare(config)
if len(warns) > 0 {
t.Fatalf("bad: %#v", warns)
}
if err != nil {
t.Fatalf("should not have error: %s", err)
}
expected := []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
if !reflect.DeepEqual(b.config.FloppyFiles, expected) {
t.Fatalf("bad: %#v", b.config.FloppyFiles)
}
}
func TestBuilderPrepare_InvalidFloppies(t *testing.T) {
var b Builder
config := testConfig()
config["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
b = Builder{}
_, errs := b.Prepare(config)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}
func TestBuilderPrepare_DiskSize(t *testing.T) { func TestBuilderPrepare_DiskSize(t *testing.T) {
var b Builder var b Builder
config := testConfig() config := testConfig()

View File

@ -14,7 +14,7 @@ import (
// Config is the configuration structure for the builder. // Config is the configuration structure for the builder.
type Config struct { type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
parallelscommon.FloppyConfig `mapstructure:",squash"` common.FloppyConfig `mapstructure:",squash"`
parallelscommon.OutputConfig `mapstructure:",squash"` parallelscommon.OutputConfig `mapstructure:",squash"`
parallelscommon.PrlctlConfig `mapstructure:",squash"` parallelscommon.PrlctlConfig `mapstructure:",squash"`
parallelscommon.PrlctlPostConfig `mapstructure:",squash"` parallelscommon.PrlctlPostConfig `mapstructure:",squash"`

View File

@ -1,9 +1,12 @@
package pvm package pvm
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"github.com/mitchellh/packer/packer"
) )
func testConfig(t *testing.T) map[string]interface{} { func testConfig(t *testing.T) map[string]interface{} {
@ -11,6 +14,7 @@ func testConfig(t *testing.T) map[string]interface{} {
"ssh_username": "foo", "ssh_username": "foo",
"shutdown_command": "foo", "shutdown_command": "foo",
"parallels_tools_flavor": "lin", "parallels_tools_flavor": "lin",
"source_path": "config_test.go",
} }
} }
@ -68,6 +72,29 @@ func TestNewConfig_sourcePath(t *testing.T) {
testConfigOk(t, warns, errs) testConfigOk(t, warns, errs)
} }
func TestNewConfig_FloppyFiles(t *testing.T) {
c := testConfig(t)
floppies_path := "../../../common/test-fixtures/floppies"
c["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
_, _, err := NewConfig(c)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
}
func TestNewConfig_InvalidFloppies(t *testing.T) {
c := testConfig(t)
c["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
_, _, errs := NewConfig(c)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}
func TestNewConfig_shutdown_timeout(t *testing.T) { func TestNewConfig_shutdown_timeout(t *testing.T) {
c := testConfig(t) c := testConfig(t)
tf := getTempFile(t) tf := getTempFile(t)

View File

@ -8,11 +8,11 @@ import (
func testConfig() map[string]interface{} { func testConfig() map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
"image": "Ubuntu-16.04", "image": "Ubuntu-16.04",
"password": "password", "password": "password",
"username": "username", "username": "username",
"snapshot_name": "packer", "snapshot_name": "packer",
"type": "profitbricks", "type": "profitbricks",
} }
} }
@ -53,4 +53,4 @@ func TestBuilderPrepare_InvalidKey(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("should have error") t.Fatal("should have error")
} }
} }

View File

@ -174,7 +174,7 @@ func (d *stepCreateServer) getImageId(imageName string, c *Config) string {
if c.DiskType == "SSD" { if c.DiskType == "SSD" {
diskType = "HDD" diskType = "HDD"
} }
if imgName != "" && strings.Contains(strings.ToLower(imgName), strings.ToLower(imageName)) && images.Items[i].Properties.ImageType == diskType && images.Items[i].Properties.Location == c.Region { if imgName != "" && strings.Contains(strings.ToLower(imgName), strings.ToLower(imageName)) && images.Items[i].Properties.ImageType == diskType && images.Items[i].Properties.Location == c.Region && images.Items[i].Properties.Public == true {
return images.Items[i].Id return images.Items[i].Id
} }
} }

View File

@ -82,6 +82,7 @@ type Config struct {
common.HTTPConfig `mapstructure:",squash"` common.HTTPConfig `mapstructure:",squash"`
common.ISOConfig `mapstructure:",squash"` common.ISOConfig `mapstructure:",squash"`
Comm communicator.Config `mapstructure:",squash"` Comm communicator.Config `mapstructure:",squash"`
common.FloppyConfig `mapstructure:",squash"`
ISOSkipCache bool `mapstructure:"iso_skip_cache"` ISOSkipCache bool `mapstructure:"iso_skip_cache"`
Accelerator string `mapstructure:"accelerator"` Accelerator string `mapstructure:"accelerator"`
@ -139,6 +140,9 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
return nil, err return nil, err
} }
var errs *packer.MultiError
warnings := make([]string, 0)
if b.config.DiskSize == 0 { if b.config.DiskSize == 0 {
b.config.DiskSize = 40000 b.config.DiskSize = 40000
} }
@ -215,9 +219,7 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
b.config.Format = "qcow2" b.config.Format = "qcow2"
} }
if b.config.FloppyFiles == nil { errs = packer.MultiErrorAppend(errs, b.config.FloppyConfig.Prepare(&b.config.ctx)...)
b.config.FloppyFiles = make([]string, 0)
}
if b.config.NetDevice == "" { if b.config.NetDevice == "" {
b.config.NetDevice = "virtio-net" b.config.NetDevice = "virtio-net"
@ -232,9 +234,6 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
b.config.Comm.SSHTimeout = b.config.SSHWaitTimeout b.config.Comm.SSHTimeout = b.config.SSHWaitTimeout
} }
var errs *packer.MultiError
warnings := make([]string, 0)
if b.config.ISOSkipCache { if b.config.ISOSkipCache {
b.config.ISOChecksumType = "none" b.config.ISOChecksumType = "none"
} }

View File

@ -1,11 +1,13 @@
package qemu package qemu
import ( import (
"github.com/mitchellh/packer/packer" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"reflect" "reflect"
"testing" "testing"
"github.com/mitchellh/packer/packer"
) )
var testPem = ` var testPem = `
@ -262,6 +264,55 @@ func TestBuilderPrepare_Format(t *testing.T) {
} }
} }
func TestBuilderPrepare_FloppyFiles(t *testing.T) {
var b Builder
config := testConfig()
delete(config, "floppy_files")
warns, err := b.Prepare(config)
if len(warns) > 0 {
t.Fatalf("bad: %#v", warns)
}
if err != nil {
t.Fatalf("bad err: %s", err)
}
if len(b.config.FloppyFiles) != 0 {
t.Fatalf("bad: %#v", b.config.FloppyFiles)
}
floppies_path := "../../common/test-fixtures/floppies"
config["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
b = Builder{}
warns, err = b.Prepare(config)
if len(warns) > 0 {
t.Fatalf("bad: %#v", warns)
}
if err != nil {
t.Fatalf("should not have error: %s", err)
}
expected := []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
if !reflect.DeepEqual(b.config.FloppyFiles, expected) {
t.Fatalf("bad: %#v", b.config.FloppyFiles)
}
}
func TestBuilderPrepare_InvalidFloppies(t *testing.T) {
var b Builder
config := testConfig()
config["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
b = Builder{}
_, errs := b.Prepare(config)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}
func TestBuilderPrepare_InvalidKey(t *testing.T) { func TestBuilderPrepare_InvalidKey(t *testing.T) {
var b Builder var b Builder
config := testConfig() config := testConfig()

View File

@ -1,19 +0,0 @@
package common
import (
"github.com/mitchellh/packer/template/interpolate"
)
// FloppyConfig is configuration related to created floppy disks and attaching
// them to a VirtualBox machine.
type FloppyConfig struct {
FloppyFiles []string `mapstructure:"floppy_files"`
}
func (c *FloppyConfig) Prepare(ctx *interpolate.Context) []error {
if c.FloppyFiles == nil {
c.FloppyFiles = make([]string, 0)
}
return nil
}

View File

@ -1,18 +0,0 @@
package common
import (
"testing"
)
func TestFloppyConfigPrepare(t *testing.T) {
c := new(FloppyConfig)
errs := c.Prepare(testConfigTemplate(t))
if len(errs) > 0 {
t.Fatalf("err: %#v", errs)
}
if len(c.FloppyFiles) > 0 {
t.Fatal("should not have floppy files")
}
}

View File

@ -26,9 +26,9 @@ type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
common.HTTPConfig `mapstructure:",squash"` common.HTTPConfig `mapstructure:",squash"`
common.ISOConfig `mapstructure:",squash"` common.ISOConfig `mapstructure:",squash"`
common.FloppyConfig `mapstructure:",squash"`
vboxcommon.ExportConfig `mapstructure:",squash"` vboxcommon.ExportConfig `mapstructure:",squash"`
vboxcommon.ExportOpts `mapstructure:",squash"` vboxcommon.ExportOpts `mapstructure:",squash"`
vboxcommon.FloppyConfig `mapstructure:",squash"`
vboxcommon.OutputConfig `mapstructure:",squash"` vboxcommon.OutputConfig `mapstructure:",squash"`
vboxcommon.RunConfig `mapstructure:",squash"` vboxcommon.RunConfig `mapstructure:",squash"`
vboxcommon.ShutdownConfig `mapstructure:",squash"` vboxcommon.ShutdownConfig `mapstructure:",squash"`

View File

@ -1,9 +1,12 @@
package iso package iso
import ( import (
"fmt"
"reflect"
"testing"
"github.com/mitchellh/packer/builder/virtualbox/common" "github.com/mitchellh/packer/builder/virtualbox/common"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"testing"
) )
func testConfig() map[string]interface{} { func testConfig() map[string]interface{} {
@ -86,6 +89,55 @@ func TestBuilderPrepare_DiskSize(t *testing.T) {
} }
} }
func TestBuilderPrepare_FloppyFiles(t *testing.T) {
var b Builder
config := testConfig()
delete(config, "floppy_files")
warns, err := b.Prepare(config)
if len(warns) > 0 {
t.Fatalf("bad: %#v", warns)
}
if err != nil {
t.Fatalf("bad err: %s", err)
}
if len(b.config.FloppyFiles) != 0 {
t.Fatalf("bad: %#v", b.config.FloppyFiles)
}
floppies_path := "../../../common/test-fixtures/floppies"
config["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
b = Builder{}
warns, err = b.Prepare(config)
if len(warns) > 0 {
t.Fatalf("bad: %#v", warns)
}
if err != nil {
t.Fatalf("should not have error: %s", err)
}
expected := []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
if !reflect.DeepEqual(b.config.FloppyFiles, expected) {
t.Fatalf("bad: %#v", b.config.FloppyFiles)
}
}
func TestBuilderPrepare_InvalidFloppies(t *testing.T) {
var b Builder
config := testConfig()
config["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
b = Builder{}
_, errs := b.Prepare(config)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}
func TestBuilderPrepare_GuestAdditionsMode(t *testing.T) { func TestBuilderPrepare_GuestAdditionsMode(t *testing.T) {
var b Builder var b Builder
config := testConfig() config := testConfig()

View File

@ -16,9 +16,9 @@ import (
type Config struct { type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
common.HTTPConfig `mapstructure:",squash"` common.HTTPConfig `mapstructure:",squash"`
common.FloppyConfig `mapstructure:",squash"`
vboxcommon.ExportConfig `mapstructure:",squash"` vboxcommon.ExportConfig `mapstructure:",squash"`
vboxcommon.ExportOpts `mapstructure:",squash"` vboxcommon.ExportOpts `mapstructure:",squash"`
vboxcommon.FloppyConfig `mapstructure:",squash"`
vboxcommon.OutputConfig `mapstructure:",squash"` vboxcommon.OutputConfig `mapstructure:",squash"`
vboxcommon.RunConfig `mapstructure:",squash"` vboxcommon.RunConfig `mapstructure:",squash"`
vboxcommon.SSHConfig `mapstructure:",squash"` vboxcommon.SSHConfig `mapstructure:",squash"`

View File

@ -1,15 +1,19 @@
package ovf package ovf
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"github.com/mitchellh/packer/packer"
) )
func testConfig(t *testing.T) map[string]interface{} { func testConfig(t *testing.T) map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
"ssh_username": "foo", "ssh_username": "foo",
"shutdown_command": "foo", "shutdown_command": "foo",
"source_path": "config_test.go",
} }
} }
@ -44,6 +48,29 @@ func testConfigOk(t *testing.T, warns []string, err error) {
} }
} }
func TestNewConfig_FloppyFiles(t *testing.T) {
c := testConfig(t)
floppies_path := "../../../common/test-fixtures/floppies"
c["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
_, _, err := NewConfig(c)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
}
func TestNewConfig_InvalidFloppies(t *testing.T) {
c := testConfig(t)
c["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
_, _, errs := NewConfig(c)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}
func TestNewConfig_sourcePath(t *testing.T) { func TestNewConfig_sourcePath(t *testing.T) {
// Bad // Bad
c := testConfig(t) c := testConfig(t)

View File

@ -46,13 +46,13 @@ func EncodeVMX(contents map[string]string) string {
// a list of VMX key fragments that the value must not be quoted // a list of VMX key fragments that the value must not be quoted
// fragments are used to cover multliples (i.e. multiple disks) // fragments are used to cover multliples (i.e. multiple disks)
// keys are still lowercase at this point, use lower fragments // keys are still lowercase at this point, use lower fragments
noQuotes := []string { noQuotes := []string{
".virtualssd", ".virtualssd",
} }
// a list of VMX key fragments that are case sensitive // a list of VMX key fragments that are case sensitive
// fragments are used to cover multliples (i.e. multiple disks) // fragments are used to cover multliples (i.e. multiple disks)
caseSensitive := []string { caseSensitive := []string{
".virtualSSD", ".virtualSSD",
} }
@ -63,7 +63,7 @@ func EncodeVMX(contents map[string]string) string {
for _, q := range noQuotes { for _, q := range noQuotes {
if strings.Contains(k, q) { if strings.Contains(k, q) {
pat = "%s = %s\n" pat = "%s = %s\n"
break; break
} }
} }
key := k key := k

View File

@ -29,8 +29,8 @@ scsi0:0.virtualSSD = 1
func TestEncodeVMX(t *testing.T) { func TestEncodeVMX(t *testing.T) {
contents := map[string]string{ contents := map[string]string{
".encoding": "UTF-8", ".encoding": "UTF-8",
"config.version": "8", "config.version": "8",
"scsi0:0.virtualssd": "1", "scsi0:0.virtualssd": "1",
} }

View File

@ -28,6 +28,7 @@ type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
common.HTTPConfig `mapstructure:",squash"` common.HTTPConfig `mapstructure:",squash"`
common.ISOConfig `mapstructure:",squash"` common.ISOConfig `mapstructure:",squash"`
common.FloppyConfig `mapstructure:",squash"`
vmwcommon.DriverConfig `mapstructure:",squash"` vmwcommon.DriverConfig `mapstructure:",squash"`
vmwcommon.OutputConfig `mapstructure:",squash"` vmwcommon.OutputConfig `mapstructure:",squash"`
vmwcommon.RunConfig `mapstructure:",squash"` vmwcommon.RunConfig `mapstructure:",squash"`
@ -40,7 +41,6 @@ type Config struct {
DiskName string `mapstructure:"vmdk_name"` DiskName string `mapstructure:"vmdk_name"`
DiskSize uint `mapstructure:"disk_size"` DiskSize uint `mapstructure:"disk_size"`
DiskTypeId string `mapstructure:"disk_type_id"` DiskTypeId string `mapstructure:"disk_type_id"`
FloppyFiles []string `mapstructure:"floppy_files"`
Format string `mapstructure:"format"` Format string `mapstructure:"format"`
GuestOSType string `mapstructure:"guest_os_type"` GuestOSType string `mapstructure:"guest_os_type"`
Version string `mapstructure:"version"` Version string `mapstructure:"version"`
@ -97,6 +97,7 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
errs = packer.MultiErrorAppend(errs, b.config.SSHConfig.Prepare(&b.config.ctx)...) errs = packer.MultiErrorAppend(errs, b.config.SSHConfig.Prepare(&b.config.ctx)...)
errs = packer.MultiErrorAppend(errs, b.config.ToolsConfig.Prepare(&b.config.ctx)...) errs = packer.MultiErrorAppend(errs, b.config.ToolsConfig.Prepare(&b.config.ctx)...)
errs = packer.MultiErrorAppend(errs, b.config.VMXConfig.Prepare(&b.config.ctx)...) errs = packer.MultiErrorAppend(errs, b.config.VMXConfig.Prepare(&b.config.ctx)...)
errs = packer.MultiErrorAppend(errs, b.config.FloppyConfig.Prepare(&b.config.ctx)...)
if b.config.DiskName == "" { if b.config.DiskName == "" {
b.config.DiskName = "disk" b.config.DiskName = "disk"
@ -115,10 +116,6 @@ func (b *Builder) Prepare(raws ...interface{}) ([]string, error) {
} }
} }
if b.config.FloppyFiles == nil {
b.config.FloppyFiles = make([]string, 0)
}
if b.config.GuestOSType == "" { if b.config.GuestOSType == "" {
b.config.GuestOSType = "other" b.config.GuestOSType = "other"
} }

View File

@ -1,6 +1,7 @@
package iso package iso
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"reflect" "reflect"
@ -106,7 +107,8 @@ func TestBuilderPrepare_FloppyFiles(t *testing.T) {
t.Fatalf("bad: %#v", b.config.FloppyFiles) t.Fatalf("bad: %#v", b.config.FloppyFiles)
} }
config["floppy_files"] = []string{"foo", "bar"} floppies_path := "../../../common/test-fixtures/floppies"
config["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
b = Builder{} b = Builder{}
warns, err = b.Prepare(config) warns, err = b.Prepare(config)
if len(warns) > 0 { if len(warns) > 0 {
@ -116,12 +118,27 @@ func TestBuilderPrepare_FloppyFiles(t *testing.T) {
t.Fatalf("should not have error: %s", err) t.Fatalf("should not have error: %s", err)
} }
expected := []string{"foo", "bar"} expected := []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
if !reflect.DeepEqual(b.config.FloppyFiles, expected) { if !reflect.DeepEqual(b.config.FloppyFiles, expected) {
t.Fatalf("bad: %#v", b.config.FloppyFiles) t.Fatalf("bad: %#v", b.config.FloppyFiles)
} }
} }
func TestBuilderPrepare_InvalidFloppies(t *testing.T) {
var b Builder
config := testConfig()
config["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
b = Builder{}
_, errs := b.Prepare(config)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}
func TestBuilderPrepare_Format(t *testing.T) { func TestBuilderPrepare_Format(t *testing.T) {
var b Builder var b Builder
config := testConfig() config := testConfig()

View File

@ -242,6 +242,11 @@ func (ESX5Driver) UpdateVMX(_, password string, port uint, data map[string]strin
func (d *ESX5Driver) CommHost(state multistep.StateBag) (string, error) { func (d *ESX5Driver) CommHost(state multistep.StateBag) (string, error) {
config := state.Get("config").(*Config) config := state.Get("config").(*Config)
sshc := config.SSHConfig.Comm
port := sshc.SSHPort
if sshc.Type == "winrm" {
port = sshc.WinRMPort
}
if address, ok := state.GetOk("vm_address"); ok { if address, ok := state.GetOk("vm_address"); ok {
return address.(string), nil return address.(string), nil
@ -286,7 +291,7 @@ func (d *ESX5Driver) CommHost(state multistep.StateBag) (string, error) {
} }
// When multiple NICs are connected to the same network, choose // When multiple NICs are connected to the same network, choose
// one that has a route back. This Dial should ensure that. // one that has a route back. This Dial should ensure that.
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", record["IPAddress"], d.Port), 2*time.Second) conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", record["IPAddress"], port), 2*time.Second)
if err != nil { if err != nil {
if e, ok := err.(*net.OpError); ok { if e, ok := err.(*net.OpError); ok {
if e.Timeout() { if e.Timeout() {

View File

@ -1,10 +1,13 @@
package vmx package vmx
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"reflect" "reflect"
"testing" "testing"
"github.com/mitchellh/packer/packer"
) )
func TestBuilderPrepare_FloppyFiles(t *testing.T) { func TestBuilderPrepare_FloppyFiles(t *testing.T) {
@ -33,7 +36,8 @@ func TestBuilderPrepare_FloppyFiles(t *testing.T) {
t.Fatalf("bad: %#v", b.config.FloppyFiles) t.Fatalf("bad: %#v", b.config.FloppyFiles)
} }
config["floppy_files"] = []string{"foo", "bar"} floppies_path := "../../../common/test-fixtures/floppies"
config["floppy_files"] = []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
b = Builder{} b = Builder{}
warns, err = b.Prepare(config) warns, err = b.Prepare(config)
if len(warns) > 0 { if len(warns) > 0 {
@ -43,8 +47,23 @@ func TestBuilderPrepare_FloppyFiles(t *testing.T) {
t.Fatalf("should not have error: %s", err) t.Fatalf("should not have error: %s", err)
} }
expected := []string{"foo", "bar"} expected := []string{fmt.Sprintf("%s/bar.bat", floppies_path), fmt.Sprintf("%s/foo.ps1", floppies_path)}
if !reflect.DeepEqual(b.config.FloppyFiles, expected) { if !reflect.DeepEqual(b.config.FloppyFiles, expected) {
t.Fatalf("bad: %#v", b.config.FloppyFiles) t.Fatalf("bad: %#v", b.config.FloppyFiles)
} }
} }
func TestBuilderPrepare_InvalidFloppies(t *testing.T) {
var b Builder
config := testConfig(t)
config["floppy_files"] = []string{"nonexistant.bat", "nonexistant.ps1"}
b = Builder{}
_, errs := b.Prepare(config)
if errs == nil {
t.Fatalf("Non existant floppies should trigger multierror")
}
if len(errs.(*packer.MultiError).Errors) != 2 {
t.Fatalf("Multierror should work and report 2 errors")
}
}

View File

@ -15,6 +15,7 @@ import (
type Config struct { type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
common.HTTPConfig `mapstructure:",squash"` common.HTTPConfig `mapstructure:",squash"`
common.FloppyConfig `mapstructure:",squash"`
vmwcommon.DriverConfig `mapstructure:",squash"` vmwcommon.DriverConfig `mapstructure:",squash"`
vmwcommon.OutputConfig `mapstructure:",squash"` vmwcommon.OutputConfig `mapstructure:",squash"`
vmwcommon.RunConfig `mapstructure:",squash"` vmwcommon.RunConfig `mapstructure:",squash"`
@ -24,7 +25,6 @@ type Config struct {
vmwcommon.VMXConfig `mapstructure:",squash"` vmwcommon.VMXConfig `mapstructure:",squash"`
BootCommand []string `mapstructure:"boot_command"` BootCommand []string `mapstructure:"boot_command"`
FloppyFiles []string `mapstructure:"floppy_files"`
RemoteType string `mapstructure:"remote_type"` RemoteType string `mapstructure:"remote_type"`
SkipCompaction bool `mapstructure:"skip_compaction"` SkipCompaction bool `mapstructure:"skip_compaction"`
SourcePath string `mapstructure:"source_path"` SourcePath string `mapstructure:"source_path"`
@ -64,6 +64,7 @@ func NewConfig(raws ...interface{}) (*Config, []string, error) {
errs = packer.MultiErrorAppend(errs, c.SSHConfig.Prepare(&c.ctx)...) errs = packer.MultiErrorAppend(errs, c.SSHConfig.Prepare(&c.ctx)...)
errs = packer.MultiErrorAppend(errs, c.ToolsConfig.Prepare(&c.ctx)...) errs = packer.MultiErrorAppend(errs, c.ToolsConfig.Prepare(&c.ctx)...)
errs = packer.MultiErrorAppend(errs, c.VMXConfig.Prepare(&c.ctx)...) errs = packer.MultiErrorAppend(errs, c.VMXConfig.Prepare(&c.ctx)...)
errs = packer.MultiErrorAppend(errs, c.FloppyConfig.Prepare(&c.ctx)...)
if c.SourcePath == "" { if c.SourcePath == "" {
errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is blank, but is required")) errs = packer.MultiErrorAppend(errs, fmt.Errorf("source_path is blank, but is required"))

View File

@ -10,6 +10,7 @@ func testConfig(t *testing.T) map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
"ssh_username": "foo", "ssh_username": "foo",
"shutdown_command": "foo", "shutdown_command": "foo",
"source_path": "config_test.go",
} }
} }

View File

@ -1,19 +1,28 @@
package common package common
import ( import (
"fmt"
"os"
"github.com/mitchellh/packer/template/interpolate" "github.com/mitchellh/packer/template/interpolate"
) )
// FloppyConfig is configuration related to created floppy disks and attaching
// them to a Parallels virtual machine.
type FloppyConfig struct { type FloppyConfig struct {
FloppyFiles []string `mapstructure:"floppy_files"` FloppyFiles []string `mapstructure:"floppy_files"`
} }
func (c *FloppyConfig) Prepare(ctx *interpolate.Context) []error { func (c *FloppyConfig) Prepare(ctx *interpolate.Context) []error {
var errs []error
if c.FloppyFiles == nil { if c.FloppyFiles == nil {
c.FloppyFiles = make([]string, 0) c.FloppyFiles = make([]string, 0)
} }
return nil for _, path := range c.FloppyFiles {
if _, err := os.Stat(path); err != nil {
errs = append(errs, fmt.Errorf("Bad Floppy disk file '%s': %s", path, err))
}
}
return errs
} }

View File

@ -0,0 +1,70 @@
package common
import (
"testing"
)
func TestNilFloppies(t *testing.T) {
c := FloppyConfig{}
errs := c.Prepare(nil)
if len(errs) != 0 {
t.Fatal("nil floppies array should not fail")
}
if len(c.FloppyFiles) > 0 {
t.Fatal("struct should not have floppy files")
}
}
func TestEmptyArrayFloppies(t *testing.T) {
c := FloppyConfig{
FloppyFiles: make([]string, 0),
}
errs := c.Prepare(nil)
if len(errs) != 0 {
t.Fatal("empty floppies array should never fail")
}
if len(c.FloppyFiles) > 0 {
t.Fatal("struct should not have floppy files")
}
}
func TestExistingFloppyFile(t *testing.T) {
c := FloppyConfig{
FloppyFiles: []string{"floppy_config.go"},
}
errs := c.Prepare(nil)
if len(errs) != 0 {
t.Fatal("array with existing floppies should not fail")
}
}
func TestNonExistingFloppyFile(t *testing.T) {
c := FloppyConfig{
FloppyFiles: []string{"floppy_config.foo"},
}
errs := c.Prepare(nil)
if len(errs) == 0 {
t.Fatal("array with non existing floppies should return errors")
}
}
func TestMultiErrorFloppyFiles(t *testing.T) {
c := FloppyConfig{
FloppyFiles: []string{"floppy_config.foo", "floppy_config.go", "floppy_config.bar", "floppy_config_test.go", "floppy_config.baz"},
}
errs := c.Prepare(nil)
if len(errs) == 0 {
t.Fatal("array with non existing floppies should return errors")
}
expectedErrors := 3
if count := len(errs); count != expectedErrors {
t.Fatalf("array with %v non existing floppy should return %v errors but it is returning %v", expectedErrors, expectedErrors, count)
}
}

View File

@ -0,0 +1 @@
Echo I am a floppy with a batch file

View File

@ -0,0 +1 @@
Write-Host "I am a floppy with some Powershell"

View File

@ -13,4 +13,3 @@ func TestCommIsCommunicator(t *testing.T) {
t.Fatalf("comm must be a communicator") t.Fatalf("comm must be a communicator")
} }
} }

View File

@ -15,11 +15,13 @@ import (
type Config struct { type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
Login bool Login bool
LoginEmail string `mapstructure:"login_email"` LoginEmail string `mapstructure:"login_email"`
LoginUsername string `mapstructure:"login_username"` LoginUsername string `mapstructure:"login_username"`
LoginPassword string `mapstructure:"login_password"` LoginPassword string `mapstructure:"login_password"`
LoginServer string `mapstructure:"login_server"` LoginServer string `mapstructure:"login_server"`
EcrLogin bool `mapstructure:"ecr_login"`
docker.AwsAccessConfig `mapstructure:",squash"`
ctx interpolate.Context ctx interpolate.Context
} }
@ -42,6 +44,9 @@ func (p *PostProcessor) Configure(raws ...interface{}) error {
return err return err
} }
if p.config.EcrLogin && p.config.LoginServer == "" {
return fmt.Errorf("ECR login requires login server to be provided.")
}
return nil return nil
} }
@ -60,7 +65,19 @@ func (p *PostProcessor) PostProcess(ui packer.Ui, artifact packer.Artifact) (pac
driver = &docker.DockerDriver{Ctx: &p.config.ctx, Ui: ui} driver = &docker.DockerDriver{Ctx: &p.config.ctx, Ui: ui}
} }
if p.config.Login { if p.config.EcrLogin {
ui.Message("Fetching ECR credentials...")
username, password, err := p.config.EcrGetLogin(p.config.LoginServer)
if err != nil {
return nil, false, err
}
p.config.LoginUsername = username
p.config.LoginPassword = password
}
if p.config.Login || p.config.EcrLogin {
ui.Message("Logging in...") ui.Message("Logging in...")
err := driver.Login( err := driver.Login(
p.config.LoginServer, p.config.LoginServer,

View File

@ -81,7 +81,7 @@ func (p *PostProcessor) PostProcess(ui packer.Ui, artifact packer.Artifact) (pac
RawStateTimeout: "5m", RawStateTimeout: "5m",
} }
exporterConfig.CalcTimeout() exporterConfig.CalcTimeout()
// Set up credentials and GCE driver. // Set up credentials and GCE driver.
b, err := ioutil.ReadFile(accountKeyFilePath) b, err := ioutil.ReadFile(accountKeyFilePath)
if err != nil { if err != nil {

View File

@ -52,6 +52,12 @@ type Config struct {
// The optional inventory groups // The optional inventory groups
InventoryGroups []string `mapstructure:"inventory_groups"` InventoryGroups []string `mapstructure:"inventory_groups"`
// The optional ansible-galaxy requirements file
GalaxyFile string `mapstructure:"galaxy_file"`
// The command to run ansible-galaxy
GalaxyCommand string
} }
type Provisioner struct { type Provisioner struct {
@ -74,6 +80,9 @@ func (p *Provisioner) Prepare(raws ...interface{}) error {
if p.config.Command == "" { if p.config.Command == "" {
p.config.Command = "ANSIBLE_FORCE_COLOR=1 PYTHONUNBUFFERED=1 ansible-playbook" p.config.Command = "ANSIBLE_FORCE_COLOR=1 PYTHONUNBUFFERED=1 ansible-playbook"
} }
if p.config.GalaxyCommand == "" {
p.config.GalaxyCommand = "ansible-galaxy"
}
if p.config.StagingDir == "" { if p.config.StagingDir == "" {
p.config.StagingDir = DefaultStagingDir p.config.StagingDir = DefaultStagingDir
@ -94,6 +103,14 @@ func (p *Provisioner) Prepare(raws ...interface{}) error {
} }
} }
// Check that the galaxy file exists, if configured
if len(p.config.GalaxyFile) > 0 {
err = validateFileConfig(p.config.GalaxyFile, "galaxy_file", true)
if err != nil {
errs = packer.MultiErrorAppend(errs, err)
}
}
// Check that the playbook_dir directory exists, if configured // Check that the playbook_dir directory exists, if configured
if len(p.config.PlaybookDir) > 0 { if len(p.config.PlaybookDir) > 0 {
if err := validateDirConfig(p.config.PlaybookDir, "playbook_dir"); err != nil { if err := validateDirConfig(p.config.PlaybookDir, "playbook_dir"); err != nil {
@ -181,6 +198,15 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
}() }()
} }
if len(p.config.GalaxyFile) > 0 {
ui.Message("Uploading galaxy file...")
src = p.config.GalaxyFile
dst = filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(src)))
if err := p.uploadFile(ui, comm, dst, src); err != nil {
return fmt.Errorf("Error uploading galaxy file: %s", err)
}
}
ui.Message("Uploading inventory file...") ui.Message("Uploading inventory file...")
src = p.config.InventoryFile src = p.config.InventoryFile
dst = filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(src))) dst = filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(src)))
@ -242,6 +268,27 @@ func (p *Provisioner) Cancel() {
os.Exit(0) os.Exit(0)
} }
func (p *Provisioner) executeGalaxy(ui packer.Ui, comm packer.Communicator) error {
rolesDir := filepath.ToSlash(filepath.Join(p.config.StagingDir, "roles"))
galaxyFile := filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(p.config.GalaxyFile)))
// ansible-galaxy install -r requirements.yml -p roles/
command := fmt.Sprintf("cd %s && %s install -r %s -p %s",
p.config.StagingDir, p.config.GalaxyCommand, galaxyFile, rolesDir)
ui.Message(fmt.Sprintf("Executing Ansible Galaxy: %s", command))
cmd := &packer.RemoteCmd{
Command: command,
}
if err := cmd.StartWithUi(comm, ui); err != nil {
return err
}
if cmd.ExitStatus != 0 {
// ansible-galaxy version 2.0.0.2 doesn't return exit codes on error..
return fmt.Errorf("Non-zero exit status: %d", cmd.ExitStatus)
}
return nil
}
func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator) error { func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator) error {
playbook := filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(p.config.PlaybookFile))) playbook := filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(p.config.PlaybookFile)))
inventory := filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(p.config.InventoryFile))) inventory := filepath.ToSlash(filepath.Join(p.config.StagingDir, filepath.Base(p.config.InventoryFile)))
@ -251,6 +298,13 @@ func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator) err
extraArgs = " " + strings.Join(p.config.ExtraArguments, " ") extraArgs = " " + strings.Join(p.config.ExtraArguments, " ")
} }
// Fetch external dependencies
if len(p.config.GalaxyFile) > 0 {
if err := p.executeGalaxy(ui, comm); err != nil {
return fmt.Errorf("Error executing Ansible Galaxy: %s", err)
}
}
command := fmt.Sprintf("cd %s && %s %s%s -c local -i %s", command := fmt.Sprintf("cd %s && %s %s%s -c local -i %s",
p.config.StagingDir, p.config.Command, playbook, extraArgs, inventory) p.config.StagingDir, p.config.Command, playbook, extraArgs, inventory)
ui.Message(fmt.Sprintf("Executing Ansible: %s", command)) ui.Message(fmt.Sprintf("Executing Ansible: %s", command))

View File

@ -8,11 +8,14 @@ import (
"io" "io"
"log" "log"
"net" "net"
"strings"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// An adapter satisfies SSH requests (from an Ansible client) by delegating SSH
// exec and subsystem commands to a packer.Communicator.
type adapter struct { type adapter struct {
done <-chan struct{} done <-chan struct{}
l net.Listener l net.Listener
@ -132,29 +135,15 @@ func (c *adapter) handleSession(newChannel ssh.NewChannel) error {
return return
} }
cmd := &packer.RemoteCmd{ go func(channel ssh.Channel) {
Stdin: channel, exit := c.exec(string(req.Payload), channel, channel, channel.Stderr())
Stdout: channel,
Stderr: channel.Stderr(),
Command: string(req.Payload),
}
if err := c.comm.Start(cmd); err != nil {
c.ui.Error(err.Error())
req.Reply(false, nil)
close(done)
return
}
go func(cmd *packer.RemoteCmd, channel ssh.Channel) {
cmd.Wait()
exitStatus := make([]byte, 4) exitStatus := make([]byte, 4)
binary.BigEndian.PutUint32(exitStatus, uint32(cmd.ExitStatus)) binary.BigEndian.PutUint32(exitStatus, uint32(exit))
channel.SendRequest("exit-status", false, exitStatus) channel.SendRequest("exit-status", false, exitStatus)
close(done) close(done)
}(cmd, channel) }(channel)
req.Reply(true, nil) req.Reply(true, nil)
case "subsystem": case "subsystem":
req, err := newSubsystemRequest(req) req, err := newSubsystemRequest(req)
if err != nil { if err != nil {
@ -170,31 +159,16 @@ func (c *adapter) handleSession(newChannel ssh.NewChannel) error {
if len(sftpCmd) == 0 { if len(sftpCmd) == 0 {
sftpCmd = "/usr/lib/sftp-server -e" sftpCmd = "/usr/lib/sftp-server -e"
} }
cmd := &packer.RemoteCmd{
Stdin: channel,
Stdout: channel,
Stderr: channel.Stderr(),
Command: sftpCmd,
}
c.ui.Say("starting sftp subsystem") c.ui.Say("starting sftp subsystem")
if err := c.comm.Start(cmd); err != nil {
c.ui.Error(err.Error())
req.Reply(false, nil)
close(done)
return
}
req.Reply(true, nil)
go func() { go func() {
cmd.Wait() _ = c.remoteExec(sftpCmd, channel, channel, channel.Stderr())
close(done) close(done)
}() }()
req.Reply(true, nil)
default: default:
c.ui.Error(fmt.Sprintf("unsupported subsystem requested: %s", req.Payload)) c.ui.Error(fmt.Sprintf("unsupported subsystem requested: %s", req.Payload))
req.Reply(false, nil) req.Reply(false, nil)
} }
default: default:
c.ui.Message(fmt.Sprintf("rejecting %s request", req.Type)) c.ui.Message(fmt.Sprintf("rejecting %s request", req.Type))
@ -211,6 +185,57 @@ func (c *adapter) Shutdown() {
c.l.Close() c.l.Close()
} }
func (c *adapter) exec(command string, in io.Reader, out io.Writer, err io.Writer) int {
var exitStatus int
switch {
case strings.HasPrefix(command, "scp ") && serveSCP(command[4:]):
err := c.scpExec(command[4:], in, out, err)
if err != nil {
log.Println(err)
exitStatus = 1
}
default:
exitStatus = c.remoteExec(command, in, out, err)
}
return exitStatus
}
func serveSCP(args string) bool {
opts, _ := scpOptions(args)
return bytes.IndexAny(opts, "tf") >= 0
}
func (c *adapter) scpExec(args string, in io.Reader, out io.Writer, err io.Writer) error {
opts, rest := scpOptions(args)
if i := bytes.IndexByte(opts, 't'); i >= 0 {
return scpUploadSession(opts, rest, in, out, c.comm)
}
if i := bytes.IndexByte(opts, 'f'); i >= 0 {
return scpDownloadSession(opts, rest, in, out, c.comm)
}
return errors.New("no scp mode specified")
}
func (c *adapter) remoteExec(command string, in io.Reader, out io.Writer, err io.Writer) int {
cmd := &packer.RemoteCmd{
Stdin: in,
Stdout: out,
Stderr: err,
Command: command,
}
if err := c.comm.Start(cmd); err != nil {
c.ui.Error(err.Error())
return cmd.ExitStatus
}
cmd.Wait()
return cmd.ExitStatus
}
type envRequest struct { type envRequest struct {
*ssh.Request *ssh.Request
Payload envRequestPayload Payload envRequestPayload

View File

@ -52,6 +52,7 @@ type Config struct {
SSHHostKeyFile string `mapstructure:"ssh_host_key_file"` SSHHostKeyFile string `mapstructure:"ssh_host_key_file"`
SSHAuthorizedKeyFile string `mapstructure:"ssh_authorized_key_file"` SSHAuthorizedKeyFile string `mapstructure:"ssh_authorized_key_file"`
SFTPCmd string `mapstructure:"sftp_command"` SFTPCmd string `mapstructure:"sftp_command"`
UseSFTP bool `mapstructure:"use_sftp"`
inventoryFile string inventoryFile string
} }
@ -106,6 +107,12 @@ func (p *Provisioner) Prepare(raws ...interface{}) error {
log.Println(p.config.SSHHostKeyFile, "does not exist") log.Println(p.config.SSHHostKeyFile, "does not exist")
errs = packer.MultiErrorAppend(errs, err) errs = packer.MultiErrorAppend(errs, err)
} }
} else {
p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_HOST_KEY_CHECKING=False")
}
if !p.config.UseSFTP {
p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_SCP_IF_SSH=True")
} }
if len(p.config.LocalPort) > 0 { if len(p.config.LocalPort) > 0 {
@ -277,7 +284,7 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
}() }()
} }
if err := p.executeAnsible(ui, comm, k.privKeyFile, !hostSigner.generated); err != nil { if err := p.executeAnsible(ui, comm, k.privKeyFile); err != nil {
return fmt.Errorf("Error executing Ansible: %s", err) return fmt.Errorf("Error executing Ansible: %s", err)
} }
@ -294,7 +301,7 @@ func (p *Provisioner) Cancel() {
os.Exit(0) os.Exit(0)
} }
func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string, checkHostKey bool) error { func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string) error {
playbook, _ := filepath.Abs(p.config.PlaybookFile) playbook, _ := filepath.Abs(p.config.PlaybookFile)
inventory := p.config.inventoryFile inventory := p.config.inventoryFile
var envvars []string var envvars []string
@ -315,10 +322,6 @@ func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, pri
cmd.Env = append(cmd.Env, envvars...) cmd.Env = append(cmd.Env, envvars...)
} }
if !checkHostKey {
cmd.Env = append(cmd.Env, "ANSIBLE_HOST_KEY_CHECKING=False")
}
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
return err return err
@ -435,7 +438,6 @@ func newUserKey(pubKeyFile string) (*userKey, error) {
type signer struct { type signer struct {
ssh.Signer ssh.Signer
generated bool
} }
func newSigner(privKeyFile string) (*signer, error) { func newSigner(privKeyFile string) (*signer, error) {
@ -464,7 +466,6 @@ func newSigner(privKeyFile string) (*signer, error) {
if err != nil { if err != nil {
return nil, errors.New("Failed to extract private key from generated key pair") return nil, errors.New("Failed to extract private key from generated key pair")
} }
signer.generated = true
return signer, nil return signer, nil
} }

338
provisioner/ansible/scp.go Normal file
View File

@ -0,0 +1,338 @@
package ansible
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/mitchellh/packer/packer"
)
const (
scpOK = "\x00"
scpEmptyError = "\x02\n"
)
/*
scp is a simple, but poorly documented, protocol. Thankfully, its source is
freely available, and there is at least one page that describes it reasonably
well.
* https://raw.githubusercontent.com/openssh/openssh-portable/master/scp.c
* https://opensource.apple.com/source/OpenSSH/OpenSSH-7.1/openssh/scp.c
* https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works is a great
resource, but has some bad information. Its first problem is that it doesn't
correctly describe why the producer has to read more responses than messages
it sends (because it has to read the 0 sent by the sink to start the
transfer). The second problem is that it omits that the producer needs to
send a 0 byte after file contents.
*/
func scpUploadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
rest = strings.TrimSpace(rest)
if len(rest) == 0 {
fmt.Fprintf(out, scpEmptyError)
return errors.New("no scp target specified")
}
d, err := ioutil.TempDir("", "packer-ansible-upload")
if err != nil {
fmt.Fprintf(out, scpEmptyError)
return err
}
defer os.RemoveAll(d)
state := &scpUploadState{destRoot: rest, srcRoot: d, comm: comm}
fmt.Fprintf(out, scpOK) // signal the client to start the transfer.
return state.Protocol(bufio.NewReader(in), out)
}
func scpDownloadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
rest = strings.TrimSpace(rest)
if len(rest) == 0 {
fmt.Fprintf(out, scpEmptyError)
return errors.New("no scp source specified")
}
d, err := ioutil.TempDir("", "packer-ansible-download")
if err != nil {
fmt.Fprintf(out, scpEmptyError)
return err
}
defer os.RemoveAll(d)
if bytes.Contains([]byte{'d'}, opts) {
// the only ansible module that supports downloading via scp is fetch,
// fetch only supports file downloads as of Ansible 2.1.
fmt.Fprintf(out, scpEmptyError)
return errors.New("directory downloads not supported")
}
f, err := os.Create(filepath.Join(d, filepath.Base(rest)))
if err != nil {
fmt.Fprintf(out, scpEmptyError)
return err
}
defer f.Close()
err = comm.Download(rest, f)
if err != nil {
fmt.Fprintf(out, scpEmptyError)
return err
}
state := &scpDownloadState{srcRoot: d}
return state.Protocol(bufio.NewReader(in), out)
}
func (state *scpDownloadState) FileProtocol(path string, info os.FileInfo, in *bufio.Reader, out io.Writer) error {
size := info.Size()
perms := fmt.Sprintf("C%04o", info.Mode().Perm())
fmt.Fprintln(out, perms, size, info.Name())
err := scpResponse(in)
if err != nil {
return err
}
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
io.CopyN(out, f, size)
fmt.Fprintf(out, scpOK)
return scpResponse(in)
}
type scpUploadState struct {
comm packer.Communicator
destRoot string // destRoot is the directory on the target
srcRoot string // srcRoot is the directory on the host
mtime time.Time
atime time.Time
dir string // dir is a path relative to the roots
}
func (scp scpUploadState) DestPath() string {
return filepath.Join(scp.destRoot, scp.dir)
}
func (scp scpUploadState) SrcPath() string {
return filepath.Join(scp.srcRoot, scp.dir)
}
func (state *scpUploadState) Protocol(in *bufio.Reader, out io.Writer) error {
for {
b, err := in.ReadByte()
if err != nil {
return err
}
switch b {
case 'T':
err := state.TimeProtocol(in, out)
if err != nil {
return err
}
case 'C':
return state.FileProtocol(in, out)
case 'E':
state.dir = filepath.Dir(state.dir)
fmt.Fprintf(out, scpOK)
return nil
case 'D':
return state.DirProtocol(in, out)
default:
fmt.Fprintf(out, scpEmptyError)
return fmt.Errorf("unexpected message: %c", b)
}
}
}
func (state *scpUploadState) FileProtocol(in *bufio.Reader, out io.Writer) error {
defer func() {
state.mtime = time.Time{}
}()
var mode os.FileMode
var size int64
var name string
_, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &size, &name)
if err != nil {
fmt.Fprintf(out, scpEmptyError)
return fmt.Errorf("invalid file message: %v", err)
}
fmt.Fprintf(out, scpOK)
var fi os.FileInfo = fileInfo{name: name, size: size, mode: mode, mtime: state.mtime}
err = state.comm.Upload(filepath.Join(state.DestPath(), fi.Name()), io.LimitReader(in, fi.Size()), &fi)
if err != nil {
fmt.Fprintf(out, scpEmptyError)
return err
}
err = scpResponse(in)
if err != nil {
return err
}
fmt.Fprintf(out, scpOK)
return nil
}
func (state *scpUploadState) TimeProtocol(in *bufio.Reader, out io.Writer) error {
var m, a int64
if _, err := fmt.Fscanf(in, "%d 0 %d 0\n", &m, &a); err != nil {
fmt.Fprintf(out, scpEmptyError)
return err
}
fmt.Fprintf(out, scpOK)
state.atime = time.Unix(a, 0)
state.mtime = time.Unix(m, 0)
return nil
}
func (state *scpUploadState) DirProtocol(in *bufio.Reader, out io.Writer) error {
var mode os.FileMode
var length uint
var name string
if _, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &length, &name); err != nil {
fmt.Fprintf(out, scpEmptyError)
return fmt.Errorf("invalid directory message: %v", err)
}
fmt.Fprintf(out, scpOK)
path := filepath.Join(state.dir, name)
if err := os.Mkdir(path, mode); err != nil {
return err
}
state.dir = path
if state.atime.IsZero() {
state.atime = time.Now()
}
if state.mtime.IsZero() {
state.mtime = time.Now()
}
if err := os.Chtimes(path, state.atime, state.mtime); err != nil {
return err
}
if err := state.comm.UploadDir(filepath.Dir(state.DestPath()), state.SrcPath(), nil); err != nil {
return err
}
state.mtime = time.Time{}
state.atime = time.Time{}
return state.Protocol(in, out)
}
type scpDownloadState struct {
srcRoot string // srcRoot is the directory on the host
}
func (state *scpDownloadState) Protocol(in *bufio.Reader, out io.Writer) error {
r := bufio.NewReader(in)
// read the byte sent by the other side to start the transfer
scpResponse(r)
return filepath.Walk(state.srcRoot, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if path == state.srcRoot {
return nil
}
if info.IsDir() {
// no need to get fancy; srcRoot should only contain one file, because
// Ansible only allows fetching a single file.
return errors.New("unexpected directory")
}
return state.FileProtocol(path, info, r, out)
})
}
func scpOptions(s string) (opts []byte, rest string) {
end := 0
opt := false
Loop:
for i := 0; i < len(s); i++ {
b := s[i]
switch {
case b == ' ':
opt = false
end++
case b == '-':
opt = true
end++
case opt:
opts = append(opts, b)
end++
default:
break Loop
}
}
rest = s[end:]
return
}
func scpResponse(r *bufio.Reader) error {
code, err := r.ReadByte()
if err != nil {
return err
}
if code != 0 {
message, err := r.ReadString('\n')
if err != nil {
return fmt.Errorf("Error reading error message: %s", err)
}
// 1 is a warning. Anything higher (really just 2) is an error.
if code > 1 {
return errors.New(string(message))
}
log.Println("WARNING:", err)
}
return nil
}
type fileInfo struct {
name string
size int64
mode os.FileMode
mtime time.Time
}
func (fi fileInfo) Name() string { return fi.name }
func (fi fileInfo) Size() int64 { return fi.size }
func (fi fileInfo) Mode() os.FileMode { return fi.mode }
func (fi fileInfo) ModTime() time.Time {
if fi.mtime.IsZero() {
return time.Now()
}
return fi.mtime
}
func (fi fileInfo) IsDir() bool { return fi.mode.IsDir() }
func (fi fileInfo) Sys() interface{} { return nil }

View File

@ -98,9 +98,9 @@ type InstallChefTemplate struct {
} }
type KnifeTemplate struct { type KnifeTemplate struct {
Sudo bool Sudo bool
Flags string Flags string
Args string Args string
} }
func (p *Provisioner) Prepare(raws ...interface{}) error { func (p *Provisioner) Prepare(raws ...interface{}) error {
@ -500,7 +500,7 @@ func (p *Provisioner) knifeExec(ui packer.Ui, comm packer.Communicator, node str
} }
p.config.ctx.Data = &KnifeTemplate{ p.config.ctx.Data = &KnifeTemplate{
Sudo: !p.config.PreventSudo, Sudo: !p.config.PreventSudo,
Flags: strings.Join(flags, " "), Flags: strings.Join(flags, " "),
Args: strings.Join(args, " "), Args: strings.Join(args, " "),
} }

View File

@ -28,12 +28,12 @@ type guestOSTypeConfig struct {
var guestOSTypeConfigs = map[string]guestOSTypeConfig{ var guestOSTypeConfigs = map[string]guestOSTypeConfig{
provisioner.UnixOSType: guestOSTypeConfig{ provisioner.UnixOSType: guestOSTypeConfig{
executeCommand: "{{if .Sudo}}sudo {{end}}chef-solo --no-color -c {{.ConfigPath}} -j {{.JsonPath}}", executeCommand: "{{if .Sudo}}sudo {{end}}chef-solo --no-color -c {{.ConfigPath}} -j {{.JsonPath}}",
installCommand: "curl -L https://www.chef.io/chef/install.sh | {{if .Sudo}}sudo {{end}}bash", installCommand: "curl -L https://omnitruck.chef.io/install.sh | {{if .Sudo}}sudo {{end}}bash",
stagingDir: "/tmp/packer-chef-client", stagingDir: "/tmp/packer-chef-client",
}, },
provisioner.WindowsOSType: guestOSTypeConfig{ provisioner.WindowsOSType: guestOSTypeConfig{
executeCommand: "c:/opscode/chef/bin/chef-solo.bat --no-color -c {{.ConfigPath}} -j {{.JsonPath}}", executeCommand: "c:/opscode/chef/bin/chef-solo.bat --no-color -c {{.ConfigPath}} -j {{.JsonPath}}",
installCommand: "powershell.exe -Command \"(New-Object System.Net.WebClient).DownloadFile('http://chef.io/chef/install.msi', 'C:\\Windows\\Temp\\chef.msi');Start-Process 'msiexec' -ArgumentList '/qb /i C:\\Windows\\Temp\\chef.msi' -NoNewWindow -Wait\"", installCommand: "powershell.exe -Command \". { iwr -useb https://omnitruck.chef.io/install.ps1 } | iex; install\"",
stagingDir: "C:/Windows/Temp/packer-chef-client", stagingDir: "C:/Windows/Temp/packer-chef-client",
}, },
} }

View File

@ -16,7 +16,7 @@ import (
type Config struct { type Config struct {
common.PackerConfig `mapstructure:",squash"` common.PackerConfig `mapstructure:",squash"`
ctx interpolate.Context ctx interpolate.Context
// The command used to execute Puppet. // The command used to execute Puppet.
ExecuteCommand string `mapstructure:"execute_command"` ExecuteCommand string `mapstructure:"execute_command"`
@ -76,11 +76,11 @@ func (p *Provisioner) Prepare(raws ...interface{}) error {
if err != nil { if err != nil {
return err return err
} }
if p.config.ExecuteCommand == "" { if p.config.ExecuteCommand == "" {
p.config.ExecuteCommand = p.commandTemplate() p.config.ExecuteCommand = p.commandTemplate()
} }
if p.config.StagingDir == "" { if p.config.StagingDir == "" {
p.config.StagingDir = "/tmp/packer-puppet-server" p.config.StagingDir = "/tmp/packer-puppet-server"
} }

View File

@ -4,9 +4,9 @@ import (
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io/ioutil" "io/ioutil"
"os" "os"
"testing"
"strings"
"regexp" "regexp"
"strings"
"testing"
) )
func testConfig() map[string]interface{} { func testConfig() map[string]interface{} {
@ -306,7 +306,7 @@ func TestProvisioner_RemotePathSetViaRemotePathAndRemoteFile(t *testing.T) {
t.Fatalf("should not have error: %s", err) t.Fatalf("should not have error: %s", err)
} }
if p.config.RemotePath != expectedRemoteFolder + "/" + expectedRemoteFile { if p.config.RemotePath != expectedRemoteFolder+"/"+expectedRemoteFile {
t.Fatalf("remote path does not contain remote_file") t.Fatalf("remote path does not contain remote_file")
} }
} }

View File

@ -7,28 +7,26 @@ load test_helper
fixtures builder-googlecompute fixtures builder-googlecompute
# Required parameters # Required parameters
: ${GC_BUCKET_NAME:?}
: ${GC_ACCOUNT_FILE:?} : ${GC_ACCOUNT_FILE:?}
: ${GC_PROJECT_ID:?} : ${GC_PROJECT_ID:?}
command -v gcutil >/dev/null 2>&1 || { command -v gcloud >/dev/null 2>&1 || {
echo "'gcutil' must be installed" >&2 echo "'gcloud' must be installed" >&2
exit 1 exit 1
} }
USER_VARS="-var bucket_name=${GC_BUCKET_NAME}"
USER_VARS="${USER_VARS} -var account_file=${GC_ACCOUNT_FILE}" USER_VARS="${USER_VARS} -var account_file=${GC_ACCOUNT_FILE}"
USER_VARS="${USER_VARS} -var project_id=${GC_PROJECT_ID}" USER_VARS="${USER_VARS} -var project_id=${GC_PROJECT_ID}"
# This tests if GCE has an image that contains the given parameter. # This tests if GCE has an image that contains the given parameter.
gc_has_image() { gc_has_image() {
gcutil --format=names --project=${GC_PROJECT_ID} listimages \ gcloud compute --format='table[no-heading](name)' --project=${GC_PROJECT_ID} images list \
| grep $1 | wc -l | grep $1 | wc -l
} }
teardown() { teardown() {
gcutil --format=names --project=${GC_PROJECT_ID} listimages \ gcloud compute --format='table[no-heading](name)' --project=${GC_PROJECT_ID} images list \
| grep packerbats \ | grep packerbats \
| xargs -n1 gcutil --project=${GC_PROJECT_ID} deleteimage --force | xargs -n1 gcloud compute --project=${GC_PROJECT_ID} images delete
} }
@test "googlecompute: build minimal.json" { @test "googlecompute: build minimal.json" {

View File

@ -1,13 +1,11 @@
{ {
"variables": { "variables": {
"bucket_name": null,
"account_file": null, "account_file": null,
"project_id": null "project_id": null
}, },
"builders": [{ "builders": [{
"type": "googlecompute", "type": "googlecompute",
"bucket_name": "{{user `bucket_name`}}",
"account_file": "{{user `account_file`}}", "account_file": "{{user `account_file`}}",
"project_id": "{{user `project_id`}}", "project_id": "{{user `project_id`}}",

View File

@ -0,0 +1,40 @@
{
"variables": {},
"provisioners": [
{
"type": "shell-local",
"command": "echo 'TODO(bhcleek): write the public key to $HOME/.ssh/known_hosts and stop using ANSIBLE_HOST_KEY_CHECKING=False'"
}, {
"type": "shell",
"inline": [
"apt-get update",
"apt-get -y install python openssh-sftp-server",
"ls -l /usr/lib"
]
}, {
"type": "ansible",
"playbook_file": "./playbook.yml",
"extra_arguments": [
"-vvvv", "--private-key", "ansible-test-id"
],
"sftp_command": "/usr/lib/sftp-server -e -l INFO",
"use_sftp": true,
"ansible_env_vars": ["PACKER_ANSIBLE_TEST=1", "ANSIBLE_HOST_KEY_CHECKING=False"],
"groups": ["PACKER_TEST"],
"empty_groups": ["PACKER_EMPTY_GROUP"],
"host_alias": "packer-test",
"user": "packer",
"local_port": 2222,
"ssh_host_key_file": "ansible-server.key",
"ssh_authorized_key_file": "ansible-test-id.pub"
}
],
"builders": [{
"type": "googlecompute",
"account_file": "{{user `account_file`}}",
"project_id": "{{user `project_id`}}",
"image_name": "packerbats-alloptions-{{timestamp}}",
"source_image": "debian-7-wheezy-v20141108",
"zone": "us-central1-a"
}]
}

View File

@ -0,0 +1 @@
This is a file

Binary file not shown.

View File

@ -0,0 +1,19 @@
{
"variables": {},
"provisioners": [
{
"type": "ansible",
"playbook_file": "./playbook.yml"
}
],
"builders": [
{
"type": "googlecompute",
"account_file": "{{user `account_file`}}",
"project_id": "{{user `project_id`}}",
"image_name": "packerbats-minimal-{{timestamp}}",
"source_image": "debian-7-wheezy-v20141108",
"zone": "us-central1-a"
}
]
}

View File

@ -0,0 +1,13 @@
---
- hosts: default:packer-test
gather_facts: no
tasks:
- raw: touch /root/ansible-raw-test
- raw: date
- command: echo "the command module"
- command: mkdir /tmp/remote-dir
args:
creates: /tmp/remote-dir
- copy: src=dir/file.txt dest=/tmp/remote-dir/file.txt
- fetch: src=/tmp/remote-dir/file.txt dest=fetched-dir validate=yes fail_on_missing=yes
- copy: src=largish-file.txt dest=/tmp/largish-file.txt

View File

@ -0,0 +1,23 @@
{
"variables": {},
"provisioners": [
{
"type": "ansible",
"playbook_file": "./playbook.yml",
"extra_arguments": [
"-vvvv"
],
"sftp_command": "/usr/bin/false"
}
],
"builders": [
{
"type": "googlecompute",
"account_file": "{{user `account_file`}}",
"project_id": "{{user `project_id`}}",
"image_name": "packerbats-scp-{{timestamp}}",
"source_image": "debian-7-wheezy-v20141108",
"zone": "us-central1-a"
}
]
}

View File

@ -0,0 +1,30 @@
{
"variables": {},
"provisioners": [
{
"type": "shell",
"inline": [
"apt-get update",
"apt-get -y install python openssh-sftp-server",
"ls -l /usr/lib",
"#/usr/lib/sftp-server -?"
]
}, {
"type": "ansible",
"playbook_file": "./playbook.yml",
"sftp_command": "/usr/lib/sftp-server -e -l INFO",
"use_sftp": true
}
],
"builders": [
{
"type": "googlecompute",
"account_file": "{{user `account_file`}}",
"project_id": "{{user `project_id`}}",
"image_name": "packerbats-sftp-{{timestamp}}",
"source_image": "debian-7-wheezy-v20141108",
"zone": "us-central1-a"
}
]
}

77
test/provisioner_ansible.bats Executable file
View File

@ -0,0 +1,77 @@
#!/usr/bin/env bats
#
# This tests the ansible provisioner on Google Cloud Provider (i.e.
# googlecompute). The teardown function will delete any images with the text
# "packerbats" within the name.
load test_helper
fixtures provisioner-ansible
# Required parameters
: ${GC_ACCOUNT_FILE:?}
: ${GC_PROJECT_ID:?}
command -v gcloud >/dev/null 2>&1 || {
echo "'gcloud' must be installed" >&2
exit 1
}
USER_VARS="${USER_VARS} -var account_file=${GC_ACCOUNT_FILE}"
USER_VARS="${USER_VARS} -var project_id=${GC_PROJECT_ID}"
# This tests if GCE has an image that contains the given parameter.
gc_has_image() {
gcloud compute --format='table[no-heading](name)' --project=${GC_PROJECT_ID} images list \
| grep $1 | wc -l
}
setup(){
rm -f $FIXTURE_ROOT/ansible-test-id
rm -f $FIXTURE_ROOT/ansible-server.key
ssh-keygen -N "" -f $FIXTURE_ROOT/ansible-test-id
ssh-keygen -N "" -f $FIXTURE_ROOT/ansible-server.key
}
teardown() {
gcloud compute --format='table[no-heading](name)' --project=${GC_PROJECT_ID} images list \
| grep packerbats \
| xargs -n1 gcloud compute --project=${GC_PROJECT_ID} images delete
rm -f $FIXTURE_ROOT/ansible-test-id
rm -f $FIXTURE_ROOT/ansible-test-id.pub
rm -f $FIXTURE_ROOT/ansible-server.key
rm -f $FIXTURE_ROOT/ansible-server.key.pub
rm -rf $FIXTURE_ROOT/fetched-dir
}
@test "ansible provisioner: build minimal.json" {
cd $FIXTURE_ROOT
run packer build ${USER_VARS} $FIXTURE_ROOT/minimal.json
[ "$status" -eq 0 ]
[ "$(gc_has_image "packerbats-minimal")" -eq 1 ]
diff -r dir fetched-dir/default/tmp/remote-dir > /dev/null
}
@test "ansible provisioner: build all_options.json" {
cd $FIXTURE_ROOT
run packer build ${USER_VARS} $FIXTURE_ROOT/all_options.json
[ "$status" -eq 0 ]
[ "$(gc_has_image "packerbats-alloptions")" -eq 1 ]
diff -r dir fetched-dir/packer-test/tmp/remote-dir > /dev/null
}
@test "ansible provisioner: build scp.json" {
cd $FIXTURE_ROOT
run packer build ${USER_VARS} $FIXTURE_ROOT/scp.json
[ "$status" -eq 0 ]
[ "$(gc_has_image "packerbats-scp")" -eq 1 ]
diff -r dir fetched-dir/default/tmp/remote-dir > /dev/null
}
@test "ansible provisioner: build sftp.json" {
cd $FIXTURE_ROOT
run packer build ${USER_VARS} $FIXTURE_ROOT/sftp.json
[ "$status" -eq 0 ]
[ "$(gc_has_image "packerbats-sftp")" -eq 1 ]
diff -r dir fetched-dir/default/tmp/remote-dir > /dev/null
}

View File

@ -42,9 +42,12 @@ type Error interface {
OrigErr() error OrigErr() error
} }
// BatchError is a batch of errors which also wraps lower level errors with code, message, // BatchError is a batch of errors which also wraps lower level errors with
// and original errors. Calling Error() will only return the error that is at the end // code, message, and original errors. Calling Error() will include all errors
// of the list. // that occurred in the batch.
//
// Deprecated: Replaced with BatchedErrors. Only defined for backwards
// compatibility.
type BatchError interface { type BatchError interface {
// Satisfy the generic error interface. // Satisfy the generic error interface.
error error
@ -59,17 +62,35 @@ type BatchError interface {
OrigErrs() []error OrigErrs() []error
} }
// BatchedErrors is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Replaces BatchError
type BatchedErrors interface {
// Satisfy the base Error interface.
Error
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// New returns an Error object described by the code, message, and origErr. // New returns an Error object described by the code, message, and origErr.
// //
// If origErr satisfies the Error interface it will not be wrapped within a new // If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned. // Error object and will instead be returned.
func New(code, message string, origErr error) Error { func New(code, message string, origErr error) Error {
return newBaseError(code, message, origErr) var errs []error
if origErr != nil {
errs = append(errs, origErr)
}
return newBaseError(code, message, errs)
} }
// NewBatchError returns an baseError with an expectation of an array of errors // NewBatchError returns an BatchedErrors with a collection of errors as an
func NewBatchError(code, message string, errs []error) BatchError { // array of errors.
return newBaseErrors(code, message, errs) func NewBatchError(code, message string, errs []error) BatchedErrors {
return newBaseError(code, message, errs)
} }
// A RequestFailure is an interface to extract request failure information from // A RequestFailure is an interface to extract request failure information from
@ -82,9 +103,9 @@ func NewBatchError(code, message string, errs []error) BatchError {
// output, err := s3manage.Upload(svc, input, opts) // output, err := s3manage.Upload(svc, input, opts)
// if err != nil { // if err != nil {
// if reqerr, ok := err.(RequestFailure); ok { // if reqerr, ok := err.(RequestFailure); ok {
// log.Printf("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID()) // log.Println("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
// } else { // } else {
// log.Printf("Error:", err.Error() // log.Println("Error:", err.Error())
// } // }
// } // }
// //

View File

@ -34,36 +34,17 @@ type baseError struct {
errs []error errs []error
} }
// newBaseError returns an error object for the code, message, and err. // newBaseError returns an error object for the code, message, and errors.
// //
// code is a short no whitespace phrase depicting the classification of // code is a short no whitespace phrase depicting the classification of
// the error that is being created. // the error that is being created.
// //
// message is the free flow string containing detailed information about the error. // message is the free flow string containing detailed information about the
// error.
// //
// origErr is the error object which will be nested under the new error to be returned. // origErrs is the error objects which will be nested under the new errors to
func newBaseError(code, message string, origErr error) *baseError { // be returned.
b := &baseError{ func newBaseError(code, message string, origErrs []error) *baseError {
code: code,
message: message,
}
if origErr != nil {
b.errs = append(b.errs, origErr)
}
return b
}
// newBaseErrors returns an error object for the code, message, and errors.
//
// code is a short no whitespace phrase depicting the classification of
// the error that is being created.
//
// message is the free flow string containing detailed information about the error.
//
// origErrs is the error objects which will be nested under the new errors to be returned.
func newBaseErrors(code, message string, origErrs []error) *baseError {
b := &baseError{ b := &baseError{
code: code, code: code,
message: message, message: message,
@ -103,19 +84,26 @@ func (b baseError) Message() string {
return b.message return b.message
} }
// OrigErr returns the original error if one was set. Nil is returned if no error // OrigErr returns the original error if one was set. Nil is returned if no
// was set. This only returns the first element in the list. If the full list is // error was set. This only returns the first element in the list. If the full
// needed, use BatchError // list is needed, use BatchedErrors.
func (b baseError) OrigErr() error { func (b baseError) OrigErr() error {
if size := len(b.errs); size > 0 { switch len(b.errs) {
case 0:
return nil
case 1:
return b.errs[0] return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
"multiple errors occurred", b.errs)
} }
return nil
} }
// OrigErrs returns the original errors if one was set. An empty slice is returned if // OrigErrs returns the original errors if one was set. An empty slice is
// no error was set:w // returned if no error was set.
func (b baseError) OrigErrs() []error { func (b baseError) OrigErrs() []error {
return b.errs return b.errs
} }
@ -133,8 +121,8 @@ type requestError struct {
requestID string requestID string
} }
// newRequestError returns a wrapped error with additional information for request // newRequestError returns a wrapped error with additional information for
// status code, and service requestID. // request status code, and service requestID.
// //
// Should be used to wrap all request which involve service requests. Even if // Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code // the request failed without a service response, but had an HTTP status code
@ -173,6 +161,15 @@ func (r requestError) RequestID() string {
return r.requestID return r.requestID
} }
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (r requestError) OrigErrs() []error {
if b, ok := r.awsError.(BatchedErrors); ok {
return b.OrigErrs()
}
return []error{r.OrigErr()}
}
// An error list that satisfies the golang interface // An error list that satisfies the golang interface
type errorList []error type errorList []error

View File

@ -3,6 +3,7 @@ package awsutil
import ( import (
"io" "io"
"reflect" "reflect"
"time"
) )
// Copy deeply copies a src structure to dst. Useful for copying request and // Copy deeply copies a src structure to dst. Useful for copying request and
@ -49,7 +50,14 @@ func rcopy(dst, src reflect.Value, root bool) {
} else { } else {
e := src.Type().Elem() e := src.Type().Elem()
if dst.CanSet() && !src.IsNil() { if dst.CanSet() && !src.IsNil() {
dst.Set(reflect.New(e)) if _, ok := src.Interface().(*time.Time); !ok {
dst.Set(reflect.New(e))
} else {
tempValue := reflect.New(e)
tempValue.Elem().Set(src.Elem())
// Sets time.Time's unexported values
dst.Set(tempValue)
}
} }
if src.Elem().IsValid() { if src.Elem().IsValid() {
// Keep the current root state since the depth hasn't changed // Keep the current root state since the depth hasn't changed

View File

@ -91,6 +91,10 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default: default:
if !v.IsValid() {
fmt.Fprint(buf, "<invalid value>")
return
}
format := "%v" format := "%v"
switch v.Interface().(type) { switch v.Interface().(type) {
case string: case string:

View File

@ -87,9 +87,18 @@ const logReqMsg = `DEBUG: Request %s/%s Details:
%s %s
-----------------------------------------------------` -----------------------------------------------------`
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
---[ REQUEST DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *request.Request) { func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody) dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
if logBody { if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's // Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
@ -107,11 +116,21 @@ const logRespMsg = `DEBUG: Response %s/%s Details:
%s %s
-----------------------------------------------------` -----------------------------------------------------`
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
---[ RESPONSE DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) { func logResponse(r *request.Request) {
var msg = "no response data" var msg = "no response data"
if r.HTTPResponse != nil { if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody) dumpedBody, err := httputil.DumpResponse(r.HTTPResponse, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
msg = string(dumpedBody) msg = string(dumpedBody)
} else if r.Error != nil { } else if r.Error != nil {
msg = r.Error.Error() msg = r.Error.Error()

View File

@ -1,8 +1,8 @@
package client package client
import ( import (
"math"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
@ -30,16 +30,61 @@ func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries return d.NumMaxRetries
} }
var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())})
// RetryRules returns the delay duration before retrying this request again // RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (rand.Intn(30) + 30) // Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
minTime = 500
}
retryCount := r.RetryCount
if retryCount > 13 {
retryCount = 13
} else if throttle && retryCount > 8 {
retryCount = 8
}
delay := (1 << uint(retryCount)) * (seededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond return time.Duration(delay) * time.Millisecond
} }
// ShouldRetry returns if the request should be retried. // ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool { func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
if r.HTTPResponse.StatusCode >= 500 { if r.HTTPResponse.StatusCode >= 500 {
return true return true
} }
return r.IsErrorRetryable() return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
if r.HTTPResponse.StatusCode == 502 ||
r.HTTPResponse.StatusCode == 503 ||
r.HTTPResponse.StatusCode == 504 {
return true
}
return r.IsErrorThrottle()
}
// lockedSource is a thread-safe implementation of rand.Source
type lockedSource struct {
lk sync.Mutex
src rand.Source
}
func (r *lockedSource) Int63() (n int64) {
r.lk.Lock()
n = r.src.Int63()
r.lk.Unlock()
return
}
func (r *lockedSource) Seed(seed int64) {
r.lk.Lock()
r.src.Seed(seed)
r.lk.Unlock()
} }

View File

@ -7,24 +7,36 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
) )
// UseServiceDefaultRetries instructs the config to use the service's own default // UseServiceDefaultRetries instructs the config to use the service's own
// number of retries. This will be the default action if Config.MaxRetries // default number of retries. This will be the default action if
// is nil also. // Config.MaxRetries is nil also.
const UseServiceDefaultRetries = -1 const UseServiceDefaultRetries = -1
// RequestRetryer is an alias for a type that implements the request.Retryer interface. // RequestRetryer is an alias for a type that implements the request.Retryer
// interface.
type RequestRetryer interface{} type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default, // A Config provides service configuration for service clients. By default,
// all clients will use the {defaults.DefaultConfig} structure. // all clients will use the defaults.DefaultConfig tructure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess, err := session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// })
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
type Config struct { type Config struct {
// Enables verbose error printing of all credential chain errors. // Enables verbose error printing of all credential chain errors.
// Should be used when wanting to see all errors while attempting to retreive // Should be used when wanting to see all errors while attempting to
// credentials. // retrieve credentials.
CredentialsChainVerboseErrors *bool CredentialsChainVerboseErrors *bool
// The credentials object to use when signing requests. Defaults to // The credentials object to use when signing requests. Defaults to a
// a chain of credential providers to search for credentials in environment // chain of credential providers to search for credentials in environment
// variables, shared credential file, and EC2 Instance Roles. // variables, shared credential file, and EC2 Instance Roles.
Credentials *credentials.Credentials Credentials *credentials.Credentials
@ -63,11 +75,12 @@ type Config struct {
Logger Logger Logger Logger
// The maximum number of times that a request will be retried for failures. // The maximum number of times that a request will be retried for failures.
// Defaults to -1, which defers the max retry setting to the service specific // Defaults to -1, which defers the max retry setting to the service
// configuration. // specific configuration.
MaxRetries *int MaxRetries *int
// Retryer guides how HTTP requests should be retried in case of recoverable failures. // Retryer guides how HTTP requests should be retried in case of
// recoverable failures.
// //
// When nil or the value does not implement the request.Retryer interface, // When nil or the value does not implement the request.Retryer interface,
// the request.DefaultRetryer will be used. // the request.DefaultRetryer will be used.
@ -82,8 +95,8 @@ type Config struct {
// //
Retryer RequestRetryer Retryer RequestRetryer
// Disables semantic parameter validation, which validates input for missing // Disables semantic parameter validation, which validates input for
// required fields and/or other semantic request input errors. // missing required fields and/or other semantic request input errors.
DisableParamValidation *bool DisableParamValidation *bool
// Disables the computation of request and response checksums, e.g., // Disables the computation of request and response checksums, e.g.,
@ -91,8 +104,8 @@ type Config struct {
DisableComputeChecksums *bool DisableComputeChecksums *bool
// Set this to `true` to force the request to use path-style addressing, // Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will // i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client
// use virtual hosted bucket addressing when possible // will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`). // (`http://BUCKET.s3.amazonaws.com/KEY`).
// //
// @note This configuration option is specific to the Amazon S3 service. // @note This configuration option is specific to the Amazon S3 service.
@ -100,28 +113,93 @@ type Config struct {
// Amazon S3: Virtual Hosting of Buckets // Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool S3ForcePathStyle *bool
// Set this to `true` to disable the EC2Metadata client from overriding the // Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
// default http.Client's Timeout. This is helpful if you do not want the EC2Metadata // header to PUT requests over 2MB of content. 100-Continue instructs the
// client to create a new http.Client. This options is only meaningful if you're not // HTTP client not to send the body until the service responds with a
// already using a custom HTTP client with the SDK. Enabled by default. // `continue` status. This is useful to prevent sending the request body
// until after the request is authenticated, and validated.
// //
// Must be set and provided to the session.New() in order to disable the EC2Metadata // http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPUT.html
// overriding the timeout for default credentials chain. //
// 100-Continue is only enabled for Go 1.6 and above. See `http.Transport`'s
// `ExpectContinueTimeout` for information on adjusting the continue wait
// timeout. https://golang.org/pkg/net/http/#Transport
//
// You should use this flag to disble 100-Continue if you experience issues
// with proxies or third party S3 compatible services.
S3Disable100Continue *bool
// Set this to `true` to enable S3 Accelerate feature. For all operations
// compatible with S3 Accelerate will use the accelerate endpoint for
// requests. Requests not compatible will fall back to normal S3 requests.
//
// The bucket must be enable for accelerate to be used with S3 client with
// accelerate enabled. If the bucket is not enabled for accelerate an error
// will be returned. The bucket name must be DNS compatible to also work
// with accelerate.
//
// Not compatible with UseDualStack requests will fail if both flags are
// specified.
S3UseAccelerate *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only
// meaningful if you're not already using a custom HTTP client with the
// SDK. Enabled by default.
//
// Must be set and provided to the session.NewSession() in order to disable
// the EC2Metadata overriding the timeout for default credentials chain.
// //
// Example: // Example:
// sess := session.New(aws.NewConfig().WithEC2MetadataDiableTimeoutOverride(true)) // sess, err := session.NewSession(aws.NewConfig().WithEC2MetadataDiableTimeoutOverride(true))
//
// svc := s3.New(sess) // svc := s3.New(sess)
// //
EC2MetadataDisableTimeoutOverride *bool EC2MetadataDisableTimeoutOverride *bool
// Instructs the endpiont to be generated for a service client to
// be the dual stack endpoint. The dual stack endpoint will support
// both IPv4 and IPv6 addressing.
//
// Setting this for a service which does not support dual stack will fail
// to make requets. It is not recommended to set this value on the session
// as it will apply to all service clients created with the session. Even
// services which don't support dual stack endpoints.
//
// If the Endpoint config value is also provided the UseDualStack flag
// will be ignored.
//
// Only supported with.
//
// sess, err := session.NewSession()
//
// svc := s3.New(sess, &aws.Config{
// UseDualStack: aws.Bool(true),
// })
UseDualStack *bool
// SleepDelay is an override for the func the SDK will call when sleeping
// during the lifecycle of a request. Specifically this will be used for
// request delays. This value should only be used for testing. To adjust
// the delay of a request see the aws/client.DefaultRetryer and
// aws/request.Retryer.
SleepDelay func(time.Duration) SleepDelay func(time.Duration)
} }
// NewConfig returns a new Config pointer that can be chained with builder methods to // NewConfig returns a new Config pointer that can be chained with builder
// set multiple configuration values inline without using pointers. // methods to set multiple configuration values inline without using pointers.
// //
// svc := s3.New(aws.NewConfig().WithRegion("us-west-2").WithMaxRetries(10)) // // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess, err := session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// )
// //
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
func NewConfig() *Config { func NewConfig() *Config {
return &Config{} return &Config{}
} }
@ -210,6 +288,27 @@ func (c *Config) WithS3ForcePathStyle(force bool) *Config {
return c return c
} }
// WithS3Disable100Continue sets a config S3Disable100Continue value returning
// a Config pointer for chaining.
func (c *Config) WithS3Disable100Continue(disable bool) *Config {
c.S3Disable100Continue = &disable
return c
}
// WithS3UseAccelerate sets a config S3UseAccelerate value returning a Config
// pointer for chaining.
func (c *Config) WithS3UseAccelerate(enable bool) *Config {
c.S3UseAccelerate = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config
// pointer for chaining.
func (c *Config) WithUseDualStack(enable bool) *Config {
c.UseDualStack = &enable
return c
}
// WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value // WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value
// returning a Config pointer for chaining. // returning a Config pointer for chaining.
func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config { func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config {
@ -288,6 +387,18 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3ForcePathStyle = other.S3ForcePathStyle dst.S3ForcePathStyle = other.S3ForcePathStyle
} }
if other.S3Disable100Continue != nil {
dst.S3Disable100Continue = other.S3Disable100Continue
}
if other.S3UseAccelerate != nil {
dst.S3UseAccelerate = other.S3UseAccelerate
}
if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack
}
if other.EC2MetadataDisableTimeoutOverride != nil { if other.EC2MetadataDisableTimeoutOverride != nil {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
} }

View File

@ -2,7 +2,7 @@ package aws
import "time" import "time"
// String returns a pointer to of the string value passed in. // String returns a pointer to the string value passed in.
func String(v string) *string { func String(v string) *string {
return &v return &v
} }
@ -61,7 +61,7 @@ func StringValueMap(src map[string]*string) map[string]string {
return dst return dst
} }
// Bool returns a pointer to of the bool value passed in. // Bool returns a pointer to the bool value passed in.
func Bool(v bool) *bool { func Bool(v bool) *bool {
return &v return &v
} }
@ -120,7 +120,7 @@ func BoolValueMap(src map[string]*bool) map[string]bool {
return dst return dst
} }
// Int returns a pointer to of the int value passed in. // Int returns a pointer to the int value passed in.
func Int(v int) *int { func Int(v int) *int {
return &v return &v
} }
@ -179,7 +179,7 @@ func IntValueMap(src map[string]*int) map[string]int {
return dst return dst
} }
// Int64 returns a pointer to of the int64 value passed in. // Int64 returns a pointer to the int64 value passed in.
func Int64(v int64) *int64 { func Int64(v int64) *int64 {
return &v return &v
} }
@ -238,7 +238,7 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
return dst return dst
} }
// Float64 returns a pointer to of the float64 value passed in. // Float64 returns a pointer to the float64 value passed in.
func Float64(v float64) *float64 { func Float64(v float64) *float64 {
return &v return &v
} }
@ -297,7 +297,7 @@ func Float64ValueMap(src map[string]*float64) map[string]float64 {
return dst return dst
} }
// Time returns a pointer to of the time.Time value passed in. // Time returns a pointer to the time.Time value passed in.
func Time(v time.Time) *time.Time { func Time(v time.Time) *time.Time {
return &v return &v
} }
@ -311,6 +311,18 @@ func TimeValue(v *time.Time) time.Time {
return time.Time{} return time.Time{}
} }
// TimeUnixMilli returns a Unix timestamp in milliseconds from "January 1, 1970 UTC".
// The result is undefined if the Unix time cannot be represented by an int64.
// Which includes calling TimeUnixMilli on a zero Time is undefined.
//
// This utility is useful for service API's such as CloudWatch Logs which require
// their unix time values to be in milliseconds.
//
// See Go stdlib https://golang.org/pkg/time/#Time.UnixNano for more information.
func TimeUnixMilli(t time.Time) int64 {
return t.UnixNano() / int64(time.Millisecond/time.Nanosecond)
}
// TimeSlice converts a slice of time.Time values into a slice of // TimeSlice converts a slice of time.Time values into a slice of
// time.Time pointers // time.Time pointers
func TimeSlice(src []time.Time) []*time.Time { func TimeSlice(src []time.Time) []*time.Time {

View File

@ -24,30 +24,38 @@ type lener interface {
// BuildContentLengthHandler builds the content length of a request based on the body, // BuildContentLengthHandler builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable // or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic. // to determine request body length and no "Content-Length" was specified it will panic.
//
// The Content-Length will only be aded to the request if the length of the body
// is greater than 0. If the body is empty or the current `Content-Length`
// header is <= 0, the header will also be stripped.
var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) { var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
return
}
var length int64 var length int64
switch body := r.Body.(type) {
case nil: if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length = 0 length, _ = strconv.ParseInt(slength, 10, 64)
case lener: } else {
length = int64(body.Len()) switch body := r.Body.(type) {
case io.Seeker: case nil:
r.BodyStart, _ = body.Seek(0, 1) length = 0
end, _ := body.Seek(0, 2) case lener:
body.Seek(r.BodyStart, 0) // make sure to seek back to original location length = int64(body.Len())
length = end - r.BodyStart case io.Seeker:
default: r.BodyStart, _ = body.Seek(0, 1)
panic("Cannot get length of body, must provide `ContentLength`") end, _ := body.Seek(0, 2)
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
} }
r.HTTPRequest.ContentLength = length if length > 0 {
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length)) r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
} else {
r.HTTPRequest.ContentLength = 0
r.HTTPRequest.Header.Del("Content-Length")
}
}} }}
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent. // SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
@ -64,6 +72,11 @@ var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *req
var err error var err error
r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest) r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil { if err != nil {
// Prevent leaking if an HTTPResponse was returned. Clean up
// the body.
if r.HTTPResponse != nil {
r.HTTPResponse.Body.Close()
}
// Capture the case where url.Error is returned for error processing // Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string // response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will // error and r.HTTPResponse is nil. Other url redirect errors will

View File

@ -1,153 +1,17 @@
package corehandlers package corehandlers
import ( import "github.com/aws/aws-sdk-go/aws/request"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// ValidateParametersHandler is a request handler to validate the input parameters. // ValidateParametersHandler is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent. // Validating parameters only has meaning if done prior to the request being sent.
var ValidateParametersHandler = request.NamedHandler{Name: "core.ValidateParametersHandler", Fn: func(r *request.Request) { var ValidateParametersHandler = request.NamedHandler{Name: "core.ValidateParametersHandler", Fn: func(r *request.Request) {
if r.ParamsFilled() { if !r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// There's no validation to be done on the contents of []byte values. Prepare
// to check validateAny arguments against that type so we can quickly skip
// them.
var byteSliceType = reflect.TypeOf([]byte(nil))
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return return
} }
switch value.Kind() { if v, ok := r.Params.(request.Validator); ok {
case reflect.Struct: if err := v.Validate(); err != nil {
v.validateStruct(value, path) r.Error = err
case reflect.Slice:
if value.Type() == byteSliceType {
// We don't need to validate the contents of []byte.
return
}
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
} }
} }
} }}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
err := validateField(f, fvalue, validateFieldRequired, validateFieldMin)
if err != nil {
v.errors = append(v.errors, fmt.Sprintf("%s: %s", err.Error(), path+prefix+f.Name))
continue
}
v.validateAny(fvalue, path+prefix+f.Name)
}
}
type validatorFunc func(f reflect.StructField, fvalue reflect.Value) error
func validateField(f reflect.StructField, fvalue reflect.Value, funcs ...validatorFunc) error {
for _, fn := range funcs {
if err := fn(f, fvalue); err != nil {
return err
}
}
return nil
}
// Validates that a field has a valid value provided for required fields.
func validateFieldRequired(f reflect.StructField, fvalue reflect.Value) error {
if f.Tag.Get("required") == "" {
return nil
}
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return fmt.Errorf("missing required parameter")
}
default:
if !fvalue.IsValid() {
return fmt.Errorf("missing required parameter")
}
}
return nil
}
// Validates that if a value is provided for a field, that value must be at
// least a minimum length.
func validateFieldMin(f reflect.StructField, fvalue reflect.Value) error {
minStr := f.Tag.Get("min")
if minStr == "" {
return nil
}
min, _ := strconv.ParseInt(minStr, 10, 64)
kind := fvalue.Kind()
if kind == reflect.Ptr {
if fvalue.IsNil() {
return nil
}
fvalue = fvalue.Elem()
}
switch fvalue.Kind() {
case reflect.String:
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
case reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return nil
}
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
// TODO min can also apply to number minimum value.
}
return nil
}

View File

@ -0,0 +1,191 @@
// Package endpointcreds provides support for retrieving credentials from an
// arbitrary HTTP endpoint.
//
// The credentials endpoint Provider can receive both static and refreshable
// credentials that will expire. Credentials are static when an "Expiration"
// value is not provided in the endpoint's response.
//
// Static credentials will never expire once they have been retrieved. The format
// of the static credentials response:
// {
// "AccessKeyId" : "MUA...",
// "SecretAccessKey" : "/7PC5om....",
// }
//
// Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
// value in the response. The format of the refreshable credentials response:
// {
// "AccessKeyId" : "MUA...",
// "SecretAccessKey" : "/7PC5om....",
// "Token" : "AQoDY....=",
// "Expiration" : "2016-02-25T06:03:31Z"
// }
//
// Errors should be returned in the following format and only returned with 400
// or 500 HTTP status codes.
// {
// "code": "ErrorCode",
// "message": "Helpful error message."
// }
package endpointcreds
import (
"encoding/json"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
// ProviderName is the name of the credentials provider.
const ProviderName = `CredentialsEndpointProvider`
// Provider satisfies the credentials.Provider interface, and is a client to
// retrieve credentials from an arbitrary endpoint.
type Provider struct {
staticCreds bool
credentials.Expiry
// Requires a AWS Client to make HTTP requests to the endpoint with.
// the Endpoint the request will be made to is provided by the aws.Config's
// Endpoint value.
Client *client.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
// from arbitrary endpoint.
func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider {
p := &Provider{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: "CredentialsEndpoint",
Endpoint: endpoint,
},
handlers,
),
}
p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
p.Client.Handlers.Validate.Clear()
p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
for _, option := range options {
option(p)
}
return p
}
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
// from an arbitrary endpoint concurrently. The client will request the
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *Provider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) Retrieve() (credentials.Value, error) {
resp, err := p.getCredentials()
if err != nil {
return credentials.Value{ProviderName: ProviderName},
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
}
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
} else {
p.staticCreds = true
}
return credentials.Value{
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.Token,
ProviderName: ProviderName,
}, nil
}
type getCredentialsOutput struct {
Expiration *time.Time
AccessKeyID string
SecretAccessKey string
Token string
}
type errorOutput struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
op := &request.Operation{
Name: "GetCredentials",
HTTPMethod: "GET",
}
out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json")
return out, req.Send()
}
func validateEndpointHandler(r *request.Request) {
if len(r.ClientInfo.Endpoint) == 0 {
r.Error = aws.ErrMissingEndpoint
}
}
func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError",
"failed to decode endpoint credentials",
err,
)
}
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
r.Error = awserr.New("SerializationError",
"failed to decode endpoint credentials",
err,
)
}
// Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error
r.Error = awserr.New(errOut.Code, errOut.Message, nil)
}

View File

@ -14,7 +14,7 @@ var (
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil) ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
) )
// A StaticProvider is a set of credentials which are set pragmatically, // A StaticProvider is a set of credentials which are set programmatically,
// and will never expire. // and will never expire.
type StaticProvider struct { type StaticProvider struct {
Value Value
@ -30,13 +30,22 @@ func NewStaticCredentials(id, secret, token string) *Credentials {
}}) }})
} }
// NewStaticCredentialsFromCreds returns a pointer to a new Credentials object
// wrapping the static credentials value provide. Same as NewStaticCredentials
// but takes the creds Value instead of individual fields
func NewStaticCredentialsFromCreds(creds Value) *Credentials {
return NewCredentials(&StaticProvider{Value: creds})
}
// Retrieve returns the credentials or error if the credentials are invalid. // Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (Value, error) { func (s *StaticProvider) Retrieve() (Value, error) {
if s.AccessKeyID == "" || s.SecretAccessKey == "" { if s.AccessKeyID == "" || s.SecretAccessKey == "" {
return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty
} }
s.Value.ProviderName = StaticProviderName if len(s.Value.ProviderName) == 0 {
s.Value.ProviderName = StaticProviderName
}
return s.Value, nil return s.Value, nil
} }

View File

@ -0,0 +1,161 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
package stscreds
import (
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)
// ProviderName provides a name of AssumeRole provider
const ProviderName = "AssumeRoleProvider"
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
type AssumeRoleProvider struct {
credentials.Expiry
// STS client to make assume role request with.
Client AssumeRoler
// Role to be assumed.
RoleARN string
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// Optional ExternalID to pass along, defaults to nil if not set.
ExternalID *string
// The policy plain text must be 2048 bytes or shorter. However, an internal
// conversion compresses it into a packed binary format with a separate limit.
// The PackedPolicySize response element indicates by percentage how close to
// the upper size limit the policy is, with 100% equaling the maximum allowed
// size.
Policy *string
// The identification number of the MFA device that is associated with the user
// who is making the AssumeRole call. Specify this value if the trust policy
// of the role being assumed includes a condition that requires MFA authentication.
// The value is either the serial number for a hardware device (such as GAHT12345678)
// or an Amazon Resource Name (ARN) for a virtual device (such as arn:aws:iam::123456789012:mfa/user).
SerialNumber *string
// The value provided by the MFA device, if the trust policy of the role being
// assumed requires MFA (that is, if the policy includes a condition that tests
// for MFA). If the role being assumed requires MFA and if the TokenCode value
// is missing or expired, the AssumeRole call returns an "access denied" error.
TokenCode *string
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type.
func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: sts.New(c),
RoleARN: roleARN,
Duration: DefaultDuration,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Takes an AssumeRoler which can be satisfiede by the STS client.
func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: svc,
RoleARN: roleARN,
Duration: DefaultDuration,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = DefaultDuration
}
input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
}
if p.Policy != nil {
input.Policy = p.Policy
}
if p.SerialNumber != nil && p.TokenCode != nil {
input.SerialNumber = p.SerialNumber
input.TokenCode = p.TokenCode
}
roleOutput, err := p.Client.AssumeRole(input)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// We will proactively generate new credentials before they expire.
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyId,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
ProviderName: ProviderName,
}, nil
}

Some files were not shown because too many files have changed in this diff Show More