diff --git a/packer/provisioner.go b/packer/provisioner.go index 3592fb919..7f8b70757 100644 --- a/packer/provisioner.go +++ b/packer/provisioner.go @@ -2,6 +2,7 @@ package packer import ( "sync" + "time" ) // A provisioner is responsible for installing and configuring software @@ -65,3 +66,79 @@ func (h *ProvisionHook) Cancel() { h.runningProvisioner.Cancel() } } + +// PausedProvisioner is a Provisioner implementation that pauses before +// the provisioner is actually run. +type PausedProvisioner struct { + PauseBefore time.Duration + Provisioner Provisioner + + cancelCh chan struct{} + doneCh chan struct{} + lock sync.Mutex +} + +func (p *PausedProvisioner) Prepare(raws ...interface{}) error { + return p.Provisioner.Prepare(raws...) +} + +func (p *PausedProvisioner) Provision(ui Ui, comm Communicator) error { + p.lock.Lock() + cancelCh := make(chan struct{}) + p.cancelCh = cancelCh + + // Setup the done channel, which is trigger when we're done + doneCh := make(chan struct{}) + defer close(doneCh) + p.doneCh = doneCh + p.lock.Unlock() + + defer func() { + p.lock.Lock() + defer p.lock.Unlock() + if p.cancelCh == cancelCh { + p.cancelCh = nil + } + if p.doneCh == doneCh { + p.doneCh = nil + } + }() + + // Use a select to determine if we get cancelled during the wait + select { + case <-time.After(p.PauseBefore): + case <-cancelCh: + return nil + } + + provDoneCh := make(chan error, 1) + go p.provision(provDoneCh, ui, comm) + + select { + case err := <-provDoneCh: + return err + case <-cancelCh: + p.Provisioner.Cancel() + return <-provDoneCh + } +} + +func (p *PausedProvisioner) Cancel() { + var doneCh chan struct{} + + p.lock.Lock() + if p.cancelCh != nil { + close(p.cancelCh) + p.cancelCh = nil + } + if p.doneCh != nil { + doneCh = p.doneCh + } + p.lock.Unlock() + + <-doneCh +} + +func (p *PausedProvisioner) provision(result chan<- error, ui Ui, comm Communicator) { + result <- p.Provisioner.Provision(ui, comm) +} diff --git a/packer/provisioner_mock.go b/packer/provisioner_mock.go index b61f642af..62b304ccb 100644 --- a/packer/provisioner_mock.go +++ b/packer/provisioner_mock.go @@ -5,11 +5,12 @@ package packer type MockProvisioner struct { ProvFunc func() error - PrepCalled bool - PrepConfigs []interface{} - ProvCalled bool - ProvUi Ui - CancelCalled bool + PrepCalled bool + PrepConfigs []interface{} + ProvCalled bool + ProvCommunicator Communicator + ProvUi Ui + CancelCalled bool } func (t *MockProvisioner) Prepare(configs ...interface{}) error { @@ -20,6 +21,7 @@ func (t *MockProvisioner) Prepare(configs ...interface{}) error { func (t *MockProvisioner) Provision(ui Ui, comm Communicator) error { t.ProvCalled = true + t.ProvCommunicator = comm t.ProvUi = ui if t.ProvFunc == nil { diff --git a/packer/provisioner_test.go b/packer/provisioner_test.go index a3d97d511..5eeebb4a3 100644 --- a/packer/provisioner_test.go +++ b/packer/provisioner_test.go @@ -80,3 +80,94 @@ func TestProvisionHook_cancel(t *testing.T) { } // TODO(mitchellh): Test that they're run in the proper order + +func TestPausedProvisioner_impl(t *testing.T) { + var _ Provisioner = new(PausedProvisioner) +} + +func TestPausedProvisionerPrepare(t *testing.T) { + mock := new(MockProvisioner) + prov := &PausedProvisioner{ + Provisioner: mock, + } + + prov.Prepare(42) + if !mock.PrepCalled { + t.Fatal("prepare should be called") + } + if mock.PrepConfigs[0] != 42 { + t.Fatal("should have proper configs") + } +} + +func TestPausedProvisionerProvision(t *testing.T) { + mock := new(MockProvisioner) + prov := &PausedProvisioner{ + Provisioner: mock, + } + + ui := testUi() + comm := new(MockCommunicator) + prov.Provision(ui, comm) + if !mock.ProvCalled { + t.Fatal("prov should be called") + } + if mock.ProvUi != ui { + t.Fatal("should have proper ui") + } + if mock.ProvCommunicator != comm { + t.Fatal("should have proper comm") + } +} + +func TestPausedProvisionerProvision_waits(t *testing.T) { + mock := new(MockProvisioner) + prov := &PausedProvisioner{ + PauseBefore: 50 * time.Millisecond, + Provisioner: mock, + } + + dataCh := make(chan struct{}) + mock.ProvFunc = func() error { + close(dataCh) + return nil + } + + go prov.Provision(testUi(), new(MockCommunicator)) + + select { + case <-time.After(10 * time.Millisecond): + case <-dataCh: + t.Fatal("should not be called") + } + + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("never called") + case <-dataCh: + } +} + +func TestPausedProvisionerCancel(t *testing.T) { + mock := new(MockProvisioner) + prov := &PausedProvisioner{ + Provisioner: mock, + } + + provCh := make(chan struct{}) + mock.ProvFunc = func() error { + close(provCh) + time.Sleep(10 * time.Millisecond) + return nil + } + + // Start provisioning and wait for it to start + go prov.Provision(testUi(), new(MockCommunicator)) + <-provCh + + // Cancel it + prov.Cancel() + if !mock.CancelCalled { + t.Fatal("cancel should be called") + } +} diff --git a/packer/template_test.go b/packer/template_test.go index 5ab991d16..51c7f5bb8 100644 --- a/packer/template_test.go +++ b/packer/template_test.go @@ -473,7 +473,7 @@ func TestParseTemplate_ProvisionerPauseBefore(t *testing.T) { if result.Provisioners[0].Type != "shell" { t.Fatalf("bad: %#v", result.Provisioners[0].Type) } - if result.Provisioners[0].pauseBefore != 10 * time.Second { + if result.Provisioners[0].pauseBefore != 10*time.Second { t.Fatalf("bad: %s", result.Provisioners[0].pauseBefore) } }