diff --git a/packer/plugin/provisioner.go b/packer/plugin/provisioner.go index 91017062b..9c3b8fb42 100644 --- a/packer/plugin/provisioner.go +++ b/packer/plugin/provisioner.go @@ -1,6 +1,7 @@ package plugin import ( + "context" "log" "github.com/hashicorp/packer/packer" @@ -20,22 +21,13 @@ func (c *cmdProvisioner) Prepare(configs ...interface{}) error { return c.p.Prepare(configs...) } -func (c *cmdProvisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (c *cmdProvisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { defer func() { r := recover() c.checkExit(r, nil) }() - return c.p.Provision(ui, comm) -} - -func (c *cmdProvisioner) Cancel() { - defer func() { - r := recover() - c.checkExit(r, nil) - }() - - c.p.Cancel() + return c.p.Provision(ctx, ui, comm) } func (c *cmdProvisioner) checkExit(p interface{}, cb func()) { diff --git a/packer/provisioner.go b/packer/provisioner.go index 666829ea3..ec350fd50 100644 --- a/packer/provisioner.go +++ b/packer/provisioner.go @@ -16,16 +16,11 @@ type Provisioner interface { // should be merged in some sane way. Prepare(...interface{}) error - // Provision is called to actually provision the machine. A UI is - // given to communicate with the user, and a communicator is given that - // is guaranteed to be connected to some machine so that provisioning - // can be done. + // Provision is called to actually provision the machine. A context is + // given for cancellation, a UI is given to communicate with the user, and + // a communicator is given that is guaranteed to be connected to some + // machine so that provisioning can be done. Provision(context.Context, Ui, Communicator) error - - // Cancel is called to cancel the provisioning. This is usually called - // while Provision is still being called. The Provisioner should act - // to stop its execution as quickly as possible in a race-free way. - Cancel() } // A HookedProvisioner represents a provisioner and information describing it @@ -40,13 +35,10 @@ type ProvisionHook struct { // The provisioners to run as part of the hook. These should already // be prepared (by calling Prepare) at some earlier stage. Provisioners []*HookedProvisioner - - lock sync.Mutex - runningProvisioner Provisioner } // Runs the provisioners in order. -func (h *ProvisionHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { +func (h *ProvisionHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error { // Shortcut if len(h.Provisioners) == 0 { return nil @@ -59,21 +51,10 @@ func (h *ProvisionHook) Run(name string, ui Ui, comm Communicator, data interfac "then a communicator is required. Please fix this to continue.") } - defer func() { - h.lock.Lock() - defer h.lock.Unlock() - - h.runningProvisioner = nil - }() - for _, p := range h.Provisioners { - h.lock.Lock() - h.runningProvisioner = p.Provisioner - h.lock.Unlock() - ts := CheckpointReporter.AddSpan(p.TypeName, "provisioner", p.Config) - err := p.Provisioner.Provision(ui, comm) + err := p.Provisioner.Provision(ctx, ui, comm) ts.End(err) if err != nil { @@ -84,91 +65,28 @@ func (h *ProvisionHook) Run(name string, ui Ui, comm Communicator, data interfac return nil } -// Cancels the provisioners that are still running. -func (h *ProvisionHook) Cancel() { - h.lock.Lock() - defer h.lock.Unlock() - - if h.runningProvisioner != nil { - h.runningProvisioner.Cancel() - } -} - // PausedProvisioner is a Provisioner implementation that pauses before // the provisioner is actually run. type PausedProvisioner struct { PauseBefore time.Duration Provisioner Provisioner - - cancelCh chan struct{} - doneCh chan struct{} - lock sync.Mutex } func (p *PausedProvisioner) Prepare(raws ...interface{}) error { return p.Provisioner.Prepare(raws...) } -func (p *PausedProvisioner) Provision(ui Ui, comm Communicator) error { - p.lock.Lock() - cancelCh := make(chan struct{}) - p.cancelCh = cancelCh - - // Setup the done channel, which is trigger when we're done - doneCh := make(chan struct{}) - defer close(doneCh) - p.doneCh = doneCh - p.lock.Unlock() - - defer func() { - p.lock.Lock() - defer p.lock.Unlock() - if p.cancelCh == cancelCh { - p.cancelCh = nil - } - if p.doneCh == doneCh { - p.doneCh = nil - } - }() +func (p *PausedProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { // Use a select to determine if we get cancelled during the wait ui.Say(fmt.Sprintf("Pausing %s before the next provisioner...", p.PauseBefore)) select { case <-time.After(p.PauseBefore): - case <-cancelCh: - return nil + case <-ctx.Done(): + return ctx.Err() } - provDoneCh := make(chan error, 1) - go p.provision(provDoneCh, ui, comm) - - select { - case err := <-provDoneCh: - return err - case <-cancelCh: - p.Provisioner.Cancel() - return <-provDoneCh - } -} - -func (p *PausedProvisioner) Cancel() { - var doneCh chan struct{} - - p.lock.Lock() - if p.cancelCh != nil { - close(p.cancelCh) - p.cancelCh = nil - } - if p.doneCh != nil { - doneCh = p.doneCh - } - p.lock.Unlock() - - <-doneCh -} - -func (p *PausedProvisioner) provision(result chan<- error, ui Ui, comm Communicator) { - result <- p.Provisioner.Provision(ui, comm) + return p.Provisioner.Provision(ctx, ui, comm) } // DebuggedProvisioner is a Provisioner implementation that waits until a key @@ -185,28 +103,7 @@ func (p *DebuggedProvisioner) Prepare(raws ...interface{}) error { return p.Provisioner.Prepare(raws...) } -func (p *DebuggedProvisioner) Provision(ui Ui, comm Communicator) error { - p.lock.Lock() - cancelCh := make(chan struct{}) - p.cancelCh = cancelCh - - // Setup the done channel, which is trigger when we're done - doneCh := make(chan struct{}) - defer close(doneCh) - p.doneCh = doneCh - p.lock.Unlock() - - defer func() { - p.lock.Lock() - defer p.lock.Unlock() - if p.cancelCh == cancelCh { - p.cancelCh = nil - } - if p.doneCh == doneCh { - p.doneCh = nil - } - }() - +func (p *DebuggedProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { // Use a select to determine if we get cancelled during the wait message := "Pausing before the next provisioner . Press enter to continue." @@ -222,38 +119,9 @@ func (p *DebuggedProvisioner) Provision(ui Ui, comm Communicator) error { select { case <-result: - case <-cancelCh: - return nil + case <-ctx.Done(): + return ctx.Err() } - provDoneCh := make(chan error, 1) - go p.provision(provDoneCh, ui, comm) - - select { - case err := <-provDoneCh: - return err - case <-cancelCh: - p.Provisioner.Cancel() - return <-provDoneCh - } -} - -func (p *DebuggedProvisioner) Cancel() { - var doneCh chan struct{} - - p.lock.Lock() - if p.cancelCh != nil { - close(p.cancelCh) - p.cancelCh = nil - } - if p.doneCh != nil { - doneCh = p.doneCh - } - p.lock.Unlock() - - <-doneCh -} - -func (p *DebuggedProvisioner) provision(result chan<- error, ui Ui, comm Communicator) { - result <- p.Provisioner.Provision(ui, comm) + return p.Provisioner.Provision(ctx, ui, comm) } diff --git a/packer/provisioner_mock.go b/packer/provisioner_mock.go index 0c7a25791..743a8e54a 100644 --- a/packer/provisioner_mock.go +++ b/packer/provisioner_mock.go @@ -1,5 +1,7 @@ package packer +import "context" + // MockProvisioner is an implementation of Provisioner that can be // used for tests. type MockProvisioner struct { @@ -19,7 +21,7 @@ func (t *MockProvisioner) Prepare(configs ...interface{}) error { return nil } -func (t *MockProvisioner) Provision(ui Ui, comm Communicator) error { +func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { t.ProvCalled = true t.ProvCommunicator = comm t.ProvUi = ui diff --git a/packer/provisioner_test.go b/packer/provisioner_test.go index 4d370ef39..9d386f41a 100644 --- a/packer/provisioner_test.go +++ b/packer/provisioner_test.go @@ -1,6 +1,7 @@ package packer import ( + "context" "sync" "testing" "time" @@ -134,7 +135,7 @@ func TestPausedProvisionerProvision(t *testing.T) { ui := testUi() comm := new(MockCommunicator) - prov.Provision(ui, comm) + prov.Provision(context.Background(), ui, comm) if !mock.ProvCalled { t.Fatal("prov should be called") } @@ -159,7 +160,7 @@ func TestPausedProvisionerProvision_waits(t *testing.T) { return nil } - go prov.Provision(testUi(), new(MockCommunicator)) + go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) select { case <-time.After(10 * time.Millisecond): @@ -188,7 +189,7 @@ func TestPausedProvisionerCancel(t *testing.T) { } // Start provisioning and wait for it to start - go prov.Provision(testUi(), new(MockCommunicator)) + go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) <-provCh // Cancel it @@ -226,7 +227,7 @@ func TestDebuggedProvisionerProvision(t *testing.T) { ui := testUi() comm := new(MockCommunicator) writeReader(ui, "\n") - prov.Provision(ui, comm) + prov.Provision(context.Background(), ui, comm) if !mock.ProvCalled { t.Fatal("prov should be called") } @@ -252,7 +253,7 @@ func TestDebuggedProvisionerCancel(t *testing.T) { } // Start provisioning and wait for it to start - go prov.Provision(testUi(), new(MockCommunicator)) + go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) <-provCh // Cancel it diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index d8b3b3f66..1a66c6057 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -1,7 +1,7 @@ package rpc import ( - "log" + "context" "net/rpc" "github.com/hashicorp/packer/packer" @@ -34,7 +34,7 @@ func (p *provisioner) Prepare(configs ...interface{}) (err error) { return } -func (p *provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { nextId := p.mux.NextId() server := newServerWithMux(p.mux, nextId) server.RegisterCommunicator(comm) @@ -44,32 +44,20 @@ func (p *provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return p.client.Call("Provisioner.Provision", nextId, new(interface{})) } -func (p *provisioner) Cancel() { - err := p.client.Call("Provisioner.Cancel", new(interface{}), new(interface{})) - if err != nil { - log.Printf("Provisioner.Cancel err: %s", err) - } -} - -func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *interface{}) error { +func (p *ProvisionerServer) Prepare(_ context.Context, args *ProvisionerPrepareArgs, reply *interface{}) error { return p.p.Prepare(args.Configs...) } -func (p *ProvisionerServer) Provision(streamId uint32, reply *interface{}) error { +func (p *ProvisionerServer) Provision(ctx context.Context, streamId uint32, reply *interface{}) error { client, err := newClientWithMux(p.mux, streamId) if err != nil { return NewBasicError(err) } defer client.Close() - if err := p.p.Provision(client.Ui(), client.Communicator()); err != nil { + if err := p.p.Provision(ctx, client.Ui(), client.Communicator()); err != nil { return NewBasicError(err) } return nil } - -func (p *ProvisionerServer) Cancel(args *interface{}, reply *interface{}) error { - p.p.Cancel() - return nil -} diff --git a/packer/rpc/provisioner_test.go b/packer/rpc/provisioner_test.go index df310369b..1904bbd53 100644 --- a/packer/rpc/provisioner_test.go +++ b/packer/rpc/provisioner_test.go @@ -1,6 +1,7 @@ package rpc import ( + "context" "reflect" "testing" @@ -17,7 +18,7 @@ func TestProvisionerRPC(t *testing.T) { defer server.Close() server.RegisterProvisioner(p) pClient := client.Provisioner() - + ctx, cancel := context.WithCancel(context.Background()) // Test Prepare config := 42 pClient.Prepare(config) @@ -32,13 +33,15 @@ func TestProvisionerRPC(t *testing.T) { // Test Provision ui := &testUi{} comm := &packer.MockCommunicator{} - pClient.Provision(ui, comm) + if err := pClient.Provision(ctx, ui, comm); err != nil { + t.Fatalf("err: %v", err) + } if !p.ProvCalled { t.Fatal("should be called") } // Test Cancel - pClient.Cancel() + cancel() if !p.CancelCalled { t.Fatal("cancel should be called") } diff --git a/provisioner/ansible-local/communicator_mock.go b/provisioner/ansible-local/communicator_mock.go index 8ed59e9bb..b87aef663 100644 --- a/provisioner/ansible-local/communicator_mock.go +++ b/provisioner/ansible-local/communicator_mock.go @@ -1,9 +1,10 @@ package ansiblelocal import ( - "github.com/hashicorp/packer/packer" "io" "os" + + "github.com/hashicorp/packer/packer" ) type communicatorMock struct { diff --git a/provisioner/ansible-local/provisioner.go b/provisioner/ansible-local/provisioner.go index 17b28fa4b..62b37c82e 100644 --- a/provisioner/ansible-local/provisioner.go +++ b/provisioner/ansible-local/provisioner.go @@ -1,6 +1,7 @@ package ansiblelocal import ( + "context" "fmt" "os" "path/filepath" @@ -185,7 +186,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with Ansible...") if len(p.config.PlaybookDir) > 0 { @@ -308,12 +309,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - func (p *Provisioner) provisionPlaybookFiles(ui packer.Ui, comm packer.Communicator) error { var playbookDir string if p.config.PlaybookDir != "" { diff --git a/provisioner/ansible-local/provisioner_test.go b/provisioner/ansible-local/provisioner_test.go index 56c85053b..d2af8abfb 100644 --- a/provisioner/ansible-local/provisioner_test.go +++ b/provisioner/ansible-local/provisioner_test.go @@ -1,6 +1,7 @@ package ansiblelocal import ( + "context" "io/ioutil" "os" "path/filepath" @@ -133,7 +134,7 @@ func TestProvisionerProvision_PlaybookFiles(t *testing.T) { } comm := &communicatorMock{} - if err := p.Provision(new(packer.NoopUi), comm); err != nil { + if err := p.Provision(context.Background(), new(packer.NoopUi), comm); err != nil { t.Fatalf("err: %s", err) } @@ -167,7 +168,7 @@ func TestProvisionerProvision_PlaybookFilesWithPlaybookDir(t *testing.T) { } comm := &communicatorMock{} - if err := p.Provision(new(packer.NoopUi), comm); err != nil { + if err := p.Provision(context.Background(), new(packer.NoopUi), comm); err != nil { t.Fatalf("err: %s", err) } @@ -374,7 +375,7 @@ func testProvisionerProvisionDockerWithPlaybookFiles(t *testing.T, templateStrin } hook := &packer.DispatchHook{Mapping: hooks} - artifact, err := builder.Run(ui, hook) + artifact, err := builder.Run(context.Background(), ui, hook) if err != nil { t.Fatalf("Error running build %s", err) } diff --git a/provisioner/ansible/provisioner.go b/provisioner/ansible/provisioner.go index 9b28c2bbb..50e0702cb 100644 --- a/provisioner/ansible/provisioner.go +++ b/provisioner/ansible/provisioner.go @@ -3,6 +3,7 @@ package ansible import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -195,7 +196,7 @@ func (p *Provisioner) getVersion() error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with Ansible...") // Interpolate env vars to check for .WinRMPassword p.config.ctx.Data = &PassthroughTemplate{ diff --git a/provisioner/ansible/provisioner_test.go b/provisioner/ansible/provisioner_test.go index 7905f5600..27fa27dc7 100644 --- a/provisioner/ansible/provisioner_test.go +++ b/provisioner/ansible/provisioner_test.go @@ -4,6 +4,7 @@ package ansible import ( "bytes" + "context" "crypto/rand" "fmt" "io" @@ -348,7 +349,7 @@ func TestAnsibleLongMessages(t *testing.T) { Writer: new(bytes.Buffer), } - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatalf("err: %s", err) } diff --git a/provisioner/breakpoint/provisioner.go b/provisioner/breakpoint/provisioner.go index b64f97759..3787ce44d 100644 --- a/provisioner/breakpoint/provisioner.go +++ b/provisioner/breakpoint/provisioner.go @@ -1,6 +1,7 @@ package breakpoint import ( + "context" "fmt" "os" @@ -40,7 +41,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { if p.config.Disable { if p.config.Note != "" { ui.Say(fmt.Sprintf( diff --git a/provisioner/chef-client/provisioner.go b/provisioner/chef-client/provisioner.go index 6d451ace4..24c963a36 100644 --- a/provisioner/chef-client/provisioner.go +++ b/provisioner/chef-client/provisioner.go @@ -5,6 +5,7 @@ package chefclient import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -233,7 +234,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { p.communicator = comm @@ -336,12 +337,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - func (p *Provisioner) uploadFile(ui packer.Ui, comm packer.Communicator, remotePath string, localPath string) error { ui.Message(fmt.Sprintf("Uploading %s...", localPath)) diff --git a/provisioner/chef-solo/provisioner.go b/provisioner/chef-solo/provisioner.go index 9ca8bba2c..f2d96d885 100644 --- a/provisioner/chef-solo/provisioner.go +++ b/provisioner/chef-solo/provisioner.go @@ -5,6 +5,7 @@ package chefsolo import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -227,7 +228,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with chef-solo") if !p.config.SkipInstall { @@ -299,12 +300,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - func (p *Provisioner) uploadDirectory(ui packer.Ui, comm packer.Communicator, dst string, src string) error { if err := p.createDir(ui, comm, dst); err != nil { return err diff --git a/provisioner/converge/provisioner.go b/provisioner/converge/provisioner.go index 16dddd73f..4a1f79cbd 100644 --- a/provisioner/converge/provisioner.go +++ b/provisioner/converge/provisioner.go @@ -5,6 +5,7 @@ package converge import ( "bytes" + "context" "errors" "fmt" @@ -105,7 +106,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { } // Provision node somehow. TODO: actual docs -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with Converge") // bootstrapping diff --git a/provisioner/file/provisioner.go b/provisioner/file/provisioner.go index 0ed980ee9..04358f98d 100644 --- a/provisioner/file/provisioner.go +++ b/provisioner/file/provisioner.go @@ -1,6 +1,7 @@ package file import ( + "context" "errors" "fmt" "io" @@ -89,7 +90,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { if p.config.Direction == "download" { return p.ProvisionDownload(ui, comm) } else { @@ -186,9 +187,3 @@ func (p *Provisioner) ProvisionUpload(ui packer.Ui, comm packer.Communicator) er } return nil } - -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} diff --git a/provisioner/file/provisioner_test.go b/provisioner/file/provisioner_test.go index 732bba3ff..46bccca6d 100644 --- a/provisioner/file/provisioner_test.go +++ b/provisioner/file/provisioner_test.go @@ -2,6 +2,7 @@ package file import ( "bytes" + "context" "io/ioutil" "os" "path/filepath" @@ -126,7 +127,7 @@ func TestProvisionerProvision_SendsFile(t *testing.T) { Writer: b, } comm := &packer.MockCommunicator{} - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatalf("should successfully provision: %s", err) } diff --git a/provisioner/inspec/provisioner.go b/provisioner/inspec/provisioner.go index ae8048f05..15c975c20 100644 --- a/provisioner/inspec/provisioner.go +++ b/provisioner/inspec/provisioner.go @@ -3,6 +3,7 @@ package inspec import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -183,7 +184,7 @@ func (p *Provisioner) getVersion() error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with Inspec...") for i, envVar := range p.config.InspecEnvVars { diff --git a/provisioner/powershell/provisioner.go b/provisioner/powershell/provisioner.go index 57b3c4849..97a460928 100644 --- a/provisioner/powershell/provisioner.go +++ b/provisioner/powershell/provisioner.go @@ -4,6 +4,7 @@ package powershell import ( "bufio" + "context" "errors" "fmt" "log" @@ -215,7 +216,7 @@ func extractScript(p *Provisioner) (string, error) { return temp.Name(), nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say(fmt.Sprintf("Provisioning with Powershell...")) p.communicator = comm diff --git a/provisioner/powershell/provisioner_test.go b/provisioner/powershell/provisioner_test.go index 7e7ab2dcd..c2252ac12 100644 --- a/provisioner/powershell/provisioner_test.go +++ b/provisioner/powershell/provisioner_test.go @@ -2,6 +2,7 @@ package powershell import ( "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -356,7 +357,7 @@ func TestProvisionerProvision_ValidExitCodes(t *testing.T) { comm := new(packer.MockCommunicator) comm.StartExitStatus = 200 p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -379,7 +380,7 @@ func TestProvisionerProvision_InvalidExitCodes(t *testing.T) { comm := new(packer.MockCommunicator) comm.StartExitStatus = 201 // Invalid! p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err == nil { t.Fatal("should have error") } @@ -400,7 +401,7 @@ func TestProvisionerProvision_Inline(t *testing.T) { p.config.PackerBuilderType = "iso" comm := new(packer.MockCommunicator) p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -420,7 +421,7 @@ func TestProvisionerProvision_Inline(t *testing.T) { config["remote_path"] = "c:/Windows/Temp/inlineScript.ps1" p.Prepare(config) - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -449,7 +450,7 @@ func TestProvisionerProvision_Scripts(t *testing.T) { p := new(Provisioner) comm := new(packer.MockCommunicator) p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -485,7 +486,7 @@ func TestProvisionerProvision_ScriptsWithEnvVars(t *testing.T) { p := new(Provisioner) comm := new(packer.MockCommunicator) p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } diff --git a/provisioner/puppet-masterless/provisioner.go b/provisioner/puppet-masterless/provisioner.go index 8368e37d3..576926944 100644 --- a/provisioner/puppet-masterless/provisioner.go +++ b/provisioner/puppet-masterless/provisioner.go @@ -4,6 +4,7 @@ package puppetmasterless import ( + "context" "fmt" "os" "path/filepath" @@ -256,7 +257,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with Puppet...") p.communicator = comm ui.Message("Creating Puppet staging directory...") @@ -364,12 +365,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - func (p *Provisioner) uploadHieraConfig(ui packer.Ui, comm packer.Communicator) (string, error) { ui.Message("Uploading hiera configuration...") f, err := os.Open(p.config.HieraConfigPath) diff --git a/provisioner/puppet-masterless/provisioner_test.go b/provisioner/puppet-masterless/provisioner_test.go index c8adce89e..f8a0ae782 100644 --- a/provisioner/puppet-masterless/provisioner_test.go +++ b/provisioner/puppet-masterless/provisioner_test.go @@ -1,6 +1,7 @@ package puppetmasterless import ( + "context" "fmt" "io/ioutil" "log" @@ -493,7 +494,7 @@ func TestProvisionerProvision_extraArguments(t *testing.T) { t.Fatalf("err: %s", err) } - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatalf("err: %s", err) } @@ -513,7 +514,7 @@ func TestProvisionerProvision_extraArguments(t *testing.T) { t.Fatalf("err: %s", err) } - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatalf("err: %s", err) } diff --git a/provisioner/puppet-server/provisioner.go b/provisioner/puppet-server/provisioner.go index 686bdf847..6d2a7b848 100644 --- a/provisioner/puppet-server/provisioner.go +++ b/provisioner/puppet-server/provisioner.go @@ -3,6 +3,7 @@ package puppetserver import ( + "context" "fmt" "os" "path/filepath" @@ -226,7 +227,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say("Provisioning with Puppet...") p.communicator = comm ui.Message("Creating Puppet staging directory...") @@ -317,12 +318,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - func (p *Provisioner) createDir(ui packer.Ui, comm packer.Communicator, dir string) error { ui.Message(fmt.Sprintf("Creating directory: %s", dir)) diff --git a/provisioner/salt-masterless/provisioner.go b/provisioner/salt-masterless/provisioner.go index 76d2ff2d3..870b022e7 100644 --- a/provisioner/salt-masterless/provisioner.go +++ b/provisioner/salt-masterless/provisioner.go @@ -4,6 +4,7 @@ package saltmasterless import ( "bytes" + "context" "errors" "fmt" "os" @@ -219,7 +220,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { var err error var src, dst string @@ -352,12 +353,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - // Prepends sudo to supplied command if config says to func (p *Provisioner) sudo(cmd string) string { if p.config.DisableSudo || (p.config.GuestOSType == provisioner.WindowsOSType) { diff --git a/provisioner/shell-local/provisioner.go b/provisioner/shell-local/provisioner.go index 16c3806e4..9223f3c52 100644 --- a/provisioner/shell-local/provisioner.go +++ b/provisioner/shell-local/provisioner.go @@ -1,6 +1,8 @@ package shell import ( + "context" + sl "github.com/hashicorp/packer/common/shell-local" "github.com/hashicorp/packer/packer" ) @@ -23,7 +25,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, _ packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, _ packer.Communicator) error { _, retErr := sl.Run(ui, &p.config) if retErr != nil { return retErr diff --git a/provisioner/shell/provisioner.go b/provisioner/shell/provisioner.go index 1fa14d9d5..cac990510 100644 --- a/provisioner/shell/provisioner.go +++ b/provisioner/shell/provisioner.go @@ -4,6 +4,7 @@ package shell import ( "bufio" + "context" "errors" "fmt" "io" @@ -182,7 +183,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { scripts := make([]string, len(p.config.Scripts)) copy(scripts, p.config.Scripts) @@ -399,12 +400,6 @@ func (p *Provisioner) cleanupRemoteFile(path string, comm packer.Communicator) e return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - // retryable will retry the given function over and over until a // non-error is returned. func (p *Provisioner) retryable(f func() error) error { diff --git a/provisioner/windows-restart/provisioner.go b/provisioner/windows-restart/provisioner.go index 0593024c4..75523d125 100644 --- a/provisioner/windows-restart/provisioner.go +++ b/provisioner/windows-restart/provisioner.go @@ -2,6 +2,7 @@ package restart import ( "bytes" + "context" "fmt" "io" @@ -92,7 +93,7 @@ func (p *Provisioner) Prepare(raws ...interface{}) error { return nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { p.cancelLock.Lock() p.cancel = make(chan struct{}) p.cancelLock.Unlock() @@ -116,10 +117,10 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return fmt.Errorf("Restart script exited with non-zero exit status: %d", cmd.ExitStatus) } - return waitForRestart(p, comm) + return waitForRestart(ctx, p, comm) } -var waitForRestart = func(p *Provisioner, comm packer.Communicator) error { +var waitForRestart = func(ctx context.Context, p *Provisioner, comm packer.Communicator) error { ui := p.ui ui.Say("Waiting for machine to restart...") waitDone := make(chan bool, 1) @@ -165,7 +166,7 @@ var waitForRestart = func(p *Provisioner, comm packer.Communicator) error { go func() { log.Printf("Waiting for machine to become available...") - err = waitForCommunicator(p) + err = waitForCommunicator(ctx, p) waitDone <- true }() @@ -199,7 +200,7 @@ WaitLoop: } -var waitForCommunicator = func(p *Provisioner) error { +var waitForCommunicator = func(ctx context.Context, p *Provisioner) error { runCustomRestartCheck := true if p.config.RestartCheckCommand == DefaultRestartCheckCommand { runCustomRestartCheck = false @@ -213,7 +214,7 @@ var waitForCommunicator = func(p *Provisioner) error { cmdRestartCheck.Command) for { select { - case <-p.cancel: + case <-ctx.Done(): log.Println("Communicator wait canceled, exiting loop") return fmt.Errorf("Communicator wait canceled") case <-time.After(retryableSleep): @@ -286,16 +287,6 @@ var waitForCommunicator = func(p *Provisioner) error { return nil } -func (p *Provisioner) Cancel() { - log.Printf("Received interrupt Cancel()") - - p.cancelLock.Lock() - defer p.cancelLock.Unlock() - if p.cancel != nil { - close(p.cancel) - } -} - // retryable will retry the given function over and over until a // non-error is returned. func (p *Provisioner) retryable(f func() error) error { diff --git a/provisioner/windows-restart/provisioner_test.go b/provisioner/windows-restart/provisioner_test.go index d4d42ba4f..b6039b1b8 100644 --- a/provisioner/windows-restart/provisioner_test.go +++ b/provisioner/windows-restart/provisioner_test.go @@ -2,6 +2,7 @@ package restart import ( "bytes" + "context" "errors" "fmt" "testing" @@ -97,14 +98,14 @@ func TestProvisionerProvision_Success(t *testing.T) { comm := new(packer.MockCommunicator) p.Prepare(config) waitForCommunicatorOld := waitForCommunicator - waitForCommunicator = func(p *Provisioner) error { + waitForCommunicator = func(context.Context, *Provisioner) error { return nil } waitForRestartOld := waitForRestart - waitForRestart = func(p *Provisioner, comm packer.Communicator) error { + waitForRestart = func(context.Context, *Provisioner, packer.Communicator) error { return nil } - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -133,14 +134,14 @@ func TestProvisionerProvision_CustomCommand(t *testing.T) { comm := new(packer.MockCommunicator) p.Prepare(config) waitForCommunicatorOld := waitForCommunicator - waitForCommunicator = func(p *Provisioner) error { + waitForCommunicator = func(context.Context, *Provisioner) error { return nil } waitForRestartOld := waitForRestart - waitForRestart = func(p *Provisioner, comm packer.Communicator) error { + waitForRestart = func(context.Context, *Provisioner, packer.Communicator) error { return nil } - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -163,7 +164,7 @@ func TestProvisionerProvision_RestartCommandFail(t *testing.T) { comm.StartExitStatus = 1 p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err == nil { t.Fatal("should have error") } @@ -179,10 +180,10 @@ func TestProvisionerProvision_WaitForRestartFail(t *testing.T) { comm := new(packer.MockCommunicator) p.Prepare(config) waitForCommunicatorOld := waitForCommunicator - waitForCommunicator = func(p *Provisioner) error { + waitForCommunicator = func(context.Context, *Provisioner) error { return fmt.Errorf("Machine did not restart properly") } - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err == nil { t.Fatal("should have error") } @@ -206,7 +207,7 @@ func TestProvision_waitForRestartTimeout(t *testing.T) { waitContinue := make(chan bool) // Block until cancel comes through - waitForCommunicator = func(p *Provisioner) error { + waitForCommunicator = func(context.Context, *Provisioner) error { for { select { case <-waitDone: @@ -216,7 +217,7 @@ func TestProvision_waitForRestartTimeout(t *testing.T) { } go func() { - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) waitDone <- true }() <-waitContinue @@ -245,7 +246,7 @@ func TestProvision_waitForCommunicator(t *testing.T) { comm.StartStdout = "WIN-V4CEJ7MC5SN restarted." comm.StartExitStatus = 1 p.Prepare(config) - err := waitForCommunicator(p) + err := waitForCommunicator(context.Background(), p) if err != nil { t.Fatalf("should not have error, got: %s", err.Error()) @@ -274,6 +275,8 @@ func TestProvision_waitForCommunicatorWithCancel(t *testing.T) { p.cancel = make(chan struct{}) var err error + ctx, cancel := context.WithCancel(context.Background()) + comm.StartStderr = "WinRM terminated" comm.StartExitStatus = 1 // Always fail p.Prepare(config) @@ -285,14 +288,14 @@ func TestProvision_waitForCommunicatorWithCancel(t *testing.T) { waitDone := make(chan bool) go func() { waitStart <- true - err = waitForCommunicator(p) + err = waitForCommunicator(ctx, p) waitDone <- true }() go func() { time.Sleep(10 * time.Millisecond) <-waitStart - p.Cancel() + cancel() }() <-waitDone @@ -347,25 +350,27 @@ func TestProvision_Cancel(t *testing.T) { waitDone := make(chan bool) // Block until cancel comes through - waitForCommunicator = func(p *Provisioner) error { + waitForCommunicator = func(ctx context.Context, p *Provisioner) error { waitStart <- true + panic("this test is incorrect") for { select { case <-p.cancel: } } } + ctx, cancel := context.WithCancel(context.Background()) // Create two go routines to provision and cancel in parallel // Provision will block until cancel happens go func() { - err = p.Provision(ui, comm) + err = p.Provision(ctx, ui, comm) waitDone <- true }() go func() { <-waitStart - p.Cancel() + cancel() }() <-waitDone diff --git a/provisioner/windows-shell/provisioner.go b/provisioner/windows-shell/provisioner.go index 6d37e2514..2fa91d3a0 100644 --- a/provisioner/windows-shell/provisioner.go +++ b/provisioner/windows-shell/provisioner.go @@ -4,6 +4,7 @@ package shell import ( "bufio" + "context" "errors" "fmt" "log" @@ -159,7 +160,7 @@ func extractScript(p *Provisioner) (string, error) { return temp.Name(), nil } -func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.Communicator) error { ui.Say(fmt.Sprintf("Provisioning with windows-shell...")) scripts := make([]string, len(p.config.Scripts)) copy(scripts, p.config.Scripts) @@ -230,12 +231,6 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } -func (p *Provisioner) Cancel() { - // Just hard quit. It isn't a big deal if what we're doing keeps - // running on the other side. - os.Exit(0) -} - // retryable will retry the given function over and over until a // non-error is returned. func (p *Provisioner) retryable(f func() error) error { diff --git a/provisioner/windows-shell/provisioner_test.go b/provisioner/windows-shell/provisioner_test.go index 5f4149478..d5500d522 100644 --- a/provisioner/windows-shell/provisioner_test.go +++ b/provisioner/windows-shell/provisioner_test.go @@ -2,6 +2,7 @@ package shell import ( "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -294,7 +295,7 @@ func TestProvisionerProvision_Inline(t *testing.T) { p.config.PackerBuilderType = "iso" comm := new(packer.MockCommunicator) p.Prepare(config) - err := p.Provision(ui, comm) + err := p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -313,7 +314,7 @@ func TestProvisionerProvision_Inline(t *testing.T) { config["remote_path"] = "c:/Windows/Temp/inlineScript.bat" p.Prepare(config) - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -344,7 +345,7 @@ func TestProvisionerProvision_Scripts(t *testing.T) { p := new(Provisioner) comm := new(packer.MockCommunicator) p.Prepare(config) - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") } @@ -383,7 +384,7 @@ func TestProvisionerProvision_ScriptsWithEnvVars(t *testing.T) { p := new(Provisioner) comm := new(packer.MockCommunicator) p.Prepare(config) - err = p.Provision(ui, comm) + err = p.Provision(context.Background(), ui, comm) if err != nil { t.Fatal("should not have error") }