change hooks to be passed a context for cancellation

we have to as it is what calls our provisioners
This commit is contained in:
Adrien Delorme 2019-03-22 14:50:33 +01:00
parent c7ce4d598e
commit 829851fc8a
12 changed files with 90 additions and 107 deletions

View File

@ -22,7 +22,7 @@ type StepProvision struct {
Comm packer.Communicator Comm packer.Communicator
} }
func (s *StepProvision) Run(_ context.Context, state multistep.StateBag) multistep.StepAction { func (s *StepProvision) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction {
comm := s.Comm comm := s.Comm
if comm == nil { if comm == nil {
raw, ok := state.Get("communicator").(packer.Communicator) raw, ok := state.Get("communicator").(packer.Communicator)
@ -38,7 +38,7 @@ func (s *StepProvision) Run(_ context.Context, state multistep.StateBag) multist
log.Println("Running the provision hook") log.Println("Running the provision hook")
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
errCh <- hook.Run(packer.HookProvision, ui, comm, nil) errCh <- hook.Run(ctx, packer.HookProvision, ui, comm, nil)
}() }()
for { for {
@ -53,7 +53,6 @@ func (s *StepProvision) Run(_ context.Context, state multistep.StateBag) multist
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
if _, ok := state.GetOk(multistep.StateCancelled); ok { if _, ok := state.GetOk(multistep.StateCancelled); ok {
log.Println("Cancelling provisioning due to interrupt...") log.Println("Cancelling provisioning due to interrupt...")
hook.Cancel()
return multistep.ActionHalt return multistep.ActionHalt
} }
} }

View File

@ -192,7 +192,7 @@ func TestBuild_Run(t *testing.T) {
// Verify hooks are dispatchable // Verify hooks are dispatchable
dispatchHook := builder.RunHook dispatchHook := builder.RunHook
dispatchHook.Run("foo", nil, nil, 42) dispatchHook.Run(ctx, "foo", nil, nil, 42)
hook := build.hooks["foo"][0].(*MockHook) hook := build.hooks["foo"][0].(*MockHook)
if !hook.RunCalled { if !hook.RunCalled {
@ -203,7 +203,7 @@ func TestBuild_Run(t *testing.T) {
} }
// Verify provisioners run // Verify provisioners run
dispatchHook.Run(HookProvision, nil, new(MockCommunicator), 42) dispatchHook.Run(ctx, HookProvision, nil, new(MockCommunicator), 42)
prov := build.provisioners[0].provisioner.(*MockProvisioner) prov := build.provisioners[0].provisioner.(*MockProvisioner)
if !prov.ProvCalled { if !prov.ProvCalled {
t.Fatal("should be called") t.Fatal("should be called")

View File

@ -1,5 +1,7 @@
package packer package packer
import "context"
// Implementers of Builder are responsible for actually building images // Implementers of Builder are responsible for actually building images
// on some platform given some configuration. // on some platform given some configuration.
// //
@ -28,9 +30,5 @@ type Builder interface {
Prepare(...interface{}) ([]string, error) Prepare(...interface{}) ([]string, error)
// Run is where the actual build should take place. It takes a Build and a Ui. // Run is where the actual build should take place. It takes a Build and a Ui.
Run(ui Ui, hook Hook) (Artifact, error) Run(context.Context, Ui, Hook) (Artifact, error)
// Cancel cancels a possibly running Builder. This should block until
// the builder actually cancels and cleans up after itself.
Cancel()
} }

View File

@ -1,6 +1,7 @@
package packer package packer
import ( import (
"context"
"errors" "errors"
) )
@ -27,7 +28,7 @@ func (tb *MockBuilder) Prepare(config ...interface{}) ([]string, error) {
return tb.PrepareWarnings, nil return tb.PrepareWarnings, nil
} }
func (tb *MockBuilder) Run(ui Ui, h Hook) (Artifact, error) { func (tb *MockBuilder) Run(ctx context.Context, ui Ui, h Hook) (Artifact, error) {
tb.RunCalled = true tb.RunCalled = true
tb.RunHook = h tb.RunHook = h
tb.RunUi = ui tb.RunUi = ui
@ -41,7 +42,7 @@ func (tb *MockBuilder) Run(ui Ui, h Hook) (Artifact, error) {
} }
if h != nil { if h != nil {
if err := h.Run(HookProvision, ui, new(MockCommunicator), nil); err != nil { if err := h.Run(ctx, HookProvision, ui, new(MockCommunicator), nil); err != nil {
return nil, err return nil, err
} }
} }
@ -50,7 +51,3 @@ func (tb *MockBuilder) Run(ui Ui, h Hook) (Artifact, error) {
IdValue: tb.ArtifactId, IdValue: tb.ArtifactId,
}, nil }, nil
} }
func (tb *MockBuilder) Cancel() {
tb.CancelCalled = true
}

View File

@ -1,7 +1,7 @@
package packer package packer
import ( import (
"sync" "context"
) )
// This is the hook that should be fired for provisioners to run. // This is the hook that should be fired for provisioners to run.
@ -21,33 +21,18 @@ const HookProvision = "packer_provision"
// must be race-free. Cancel should attempt to cancel the hook in the // must be race-free. Cancel should attempt to cancel the hook in the
// quickest, safest way possible. // quickest, safest way possible.
type Hook interface { type Hook interface {
Run(string, Ui, Communicator, interface{}) error Run(context.Context, string, Ui, Communicator, interface{}) error
Cancel()
} }
// A Hook implementation that dispatches based on an internal mapping. // A Hook implementation that dispatches based on an internal mapping.
type DispatchHook struct { type DispatchHook struct {
Mapping map[string][]Hook Mapping map[string][]Hook
l sync.Mutex
cancelled bool
runningHook Hook
} }
// Runs the hook with the given name by dispatching it to the proper // Runs the hook with the given name by dispatching it to the proper
// hooks if a mapping exists. If a mapping doesn't exist, then nothing // hooks if a mapping exists. If a mapping doesn't exist, then nothing
// happens. // happens.
func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { func (h *DispatchHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error {
h.l.Lock()
h.cancelled = false
h.l.Unlock()
// Make sure when we exit that we reset the running hook.
defer func() {
h.l.Lock()
defer h.l.Unlock()
h.runningHook = nil
}()
hooks, ok := h.Mapping[name] hooks, ok := h.Mapping[name]
if !ok { if !ok {
@ -56,32 +41,14 @@ func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface
} }
for _, hook := range hooks { for _, hook := range hooks {
h.l.Lock() if err := ctx.Err(); err != nil {
if h.cancelled { return err
h.l.Unlock()
return nil
} }
h.runningHook = hook if err := hook.Run(ctx, name, ui, comm, data); err != nil {
h.l.Unlock()
if err := hook.Run(name, ui, comm, data); err != nil {
return err return err
} }
} }
return nil return nil
} }
// Cancels all the hooks that are currently in-flight, if any. This will
// block until the hooks are all cancelled.
func (h *DispatchHook) Cancel() {
h.l.Lock()
defer h.l.Unlock()
if h.runningHook != nil {
h.runningHook.Cancel()
}
h.cancelled = true
}

View File

@ -1,5 +1,10 @@
package packer package packer
import (
"context"
"time"
)
// MockHook is an implementation of Hook that can be used for tests. // MockHook is an implementation of Hook that can be used for tests.
type MockHook struct { type MockHook struct {
RunFunc func() error RunFunc func() error
@ -12,7 +17,16 @@ type MockHook struct {
CancelCalled bool CancelCalled bool
} }
func (t *MockHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error {
go func() {
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
t.CancelCalled = true
}
}()
t.RunCalled = true t.RunCalled = true
t.RunComm = comm t.RunComm = comm
t.RunData = data t.RunData = data
@ -25,7 +39,3 @@ func (t *MockHook) Run(name string, ui Ui, comm Communicator, data interface{})
return t.RunFunc() return t.RunFunc()
} }
func (t *MockHook) Cancel() {
t.CancelCalled = true
}

View File

@ -1,6 +1,7 @@
package packer package packer
import ( import (
"context"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -15,7 +16,15 @@ type CancelHook struct {
Cancelled bool Cancelled bool
} }
func (h *CancelHook) Run(string, Ui, Communicator, interface{}) error { func (h *CancelHook) Run(ctx context.Context, _ string, _ Ui, _ Communicator, _ interface{}) error {
go func() {
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
h.cancel()
}
}()
h.Lock() h.Lock()
h.cancelCh = make(chan struct{}) h.cancelCh = make(chan struct{})
h.doneCh = make(chan struct{}) h.doneCh = make(chan struct{})
@ -32,7 +41,7 @@ func (h *CancelHook) Run(string, Ui, Communicator, interface{}) error {
return nil return nil
} }
func (h *CancelHook) Cancel() { func (h *CancelHook) cancel() {
h.Lock() h.Lock()
close(h.cancelCh) close(h.cancelCh)
h.Unlock() h.Unlock()
@ -47,7 +56,7 @@ func TestDispatchHook_Implements(t *testing.T) {
func TestDispatchHook_Run_NoHooks(t *testing.T) { func TestDispatchHook_Run_NoHooks(t *testing.T) {
// Just make sure nothing blows up // Just make sure nothing blows up
dh := &DispatchHook{} dh := &DispatchHook{}
dh.Run("foo", nil, nil, nil) dh.Run(context.Background(), "foo", nil, nil, nil)
} }
func TestDispatchHook_Run(t *testing.T) { func TestDispatchHook_Run(t *testing.T) {
@ -56,7 +65,7 @@ func TestDispatchHook_Run(t *testing.T) {
mapping := make(map[string][]Hook) mapping := make(map[string][]Hook)
mapping["foo"] = []Hook{hook} mapping["foo"] = []Hook{hook}
dh := &DispatchHook{Mapping: mapping} dh := &DispatchHook{Mapping: mapping}
dh.Run("foo", nil, nil, 42) dh.Run(context.Background(), "foo", nil, nil, 42)
if !hook.RunCalled { if !hook.RunCalled {
t.Fatal("should be called") t.Fatal("should be called")
@ -77,10 +86,10 @@ func TestDispatchHook_cancel(t *testing.T) {
"foo": {hook}, "foo": {hook},
}, },
} }
ctx, cancel := context.WithCancel(context.Background())
go dh.Run("foo", nil, nil, 42) go dh.Run(ctx, "foo", nil, nil, 42)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
dh.Cancel() cancel()
if !hook.Cancelled { if !hook.Cancelled {
t.Fatal("hook should've cancelled") t.Fatal("hook should've cancelled")

View File

@ -1,6 +1,7 @@
package plugin package plugin
import ( import (
"context"
"log" "log"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
@ -11,22 +12,13 @@ type cmdHook struct {
client *Client client *Client
} }
func (c *cmdHook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { func (c *cmdHook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.Communicator, data interface{}) error {
defer func() { defer func() {
r := recover() r := recover()
c.checkExit(r, nil) c.checkExit(r, nil)
}() }()
return c.hook.Run(name, ui, comm, data) return c.hook.Run(ctx, name, ui, comm, data)
}
func (c *cmdHook) Cancel() {
defer func() {
r := recover()
c.checkExit(r, nil)
}()
c.hook.Cancel()
} }
func (c *cmdHook) checkExit(p interface{}, cb func()) { func (c *cmdHook) checkExit(p interface{}, cb func()) {

View File

@ -1,6 +1,9 @@
package packer package packer
import "context" import (
"context"
"time"
)
// MockProvisioner is an implementation of Provisioner that can be // MockProvisioner is an implementation of Provisioner that can be
// used for tests. // used for tests.
@ -22,6 +25,14 @@ func (t *MockProvisioner) Prepare(configs ...interface{}) error {
} }
func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error {
go func() {
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
t.CancelCalled = true
}
}()
t.ProvCalled = true t.ProvCalled = true
t.ProvCommunicator = comm t.ProvCommunicator = comm
t.ProvUi = ui t.ProvUi = ui
@ -33,10 +44,6 @@ func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicato
return t.ProvFunc() return t.ProvFunc()
} }
func (t *MockProvisioner) Cancel() {
t.CancelCalled = true
}
func (t *MockProvisioner) Communicator() Communicator { func (t *MockProvisioner) Communicator() Communicator {
return t.ProvCommunicator return t.ProvCommunicator
} }

View File

@ -30,7 +30,7 @@ func TestProvisionHook(t *testing.T) {
}, },
} }
hook.Run("foo", ui, comm, data) hook.Run(context.Background(), "foo", ui, comm, data)
if !pA.ProvCalled { if !pA.ProvCalled {
t.Error("provision should be called on pA") t.Error("provision should be called on pA")
@ -56,7 +56,7 @@ func TestProvisionHook_nilComm(t *testing.T) {
}, },
} }
err := hook.Run("foo", ui, comm, data) err := hook.Run(context.Background(), "foo", ui, comm, data)
if err == nil { if err == nil {
t.Fatal("should error") t.Fatal("should error")
} }
@ -83,16 +83,17 @@ func TestProvisionHook_cancel(t *testing.T) {
{p, nil, ""}, {p, nil, ""},
}, },
} }
ctx, cancel := context.WithCancel(context.Background())
finished := make(chan struct{}) finished := make(chan struct{})
go func() { go func() {
hook.Run("foo", nil, new(MockCommunicator), nil) hook.Run(ctx, "foo", nil, new(MockCommunicator), nil)
close(finished) close(finished)
}() }()
// Cancel it while it is running // Cancel it while it is running
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
hook.Cancel() cancel()
lock.Lock() lock.Lock()
order = append(order, "cancel") order = append(order, "cancel")
lock.Unlock() lock.Unlock()
@ -187,13 +188,15 @@ func TestPausedProvisionerCancel(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
return nil return nil
} }
ctx, cancel := context.WithCancel(context.Background())
// Start provisioning and wait for it to start // Start provisioning and wait for it to start
go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) go func() {
<-provCh <-provCh
cancel()
}()
// Cancel it prov.Provision(ctx, testUi(), new(MockCommunicator))
prov.Cancel()
if !mock.CancelCalled { if !mock.CancelCalled {
t.Fatal("cancel should be called") t.Fatal("cancel should be called")
} }
@ -251,13 +254,15 @@ func TestDebuggedProvisionerCancel(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
return nil return nil
} }
ctx, cancel := context.WithCancel(context.Background())
// Start provisioning and wait for it to start // Start provisioning and wait for it to start
go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) go func() {
<-provCh <-provCh
cancel()
}()
// Cancel it prov.Provision(ctx, testUi(), new(MockCommunicator))
prov.Cancel()
if !mock.CancelCalled { if !mock.CancelCalled {
t.Fatal("cancel should be called") t.Fatal("cancel should be called")
} }

View File

@ -1,6 +1,7 @@
package rpc package rpc
import ( import (
"context"
"log" "log"
"net/rpc" "net/rpc"
@ -27,7 +28,7 @@ type HookRunArgs struct {
StreamId uint32 StreamId uint32
} }
func (h *hook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.Communicator, data interface{}) error {
nextId := h.mux.NextId() nextId := h.mux.NextId()
server := newServerWithMux(h.mux, nextId) server := newServerWithMux(h.mux, nextId)
server.RegisterCommunicator(comm) server.RegisterCommunicator(comm)
@ -50,22 +51,17 @@ func (h *hook) Cancel() {
} }
} }
func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { func (h *HookServer) Run(ctx context.Context, args *HookRunArgs, reply *interface{}) error {
client, err := newClientWithMux(h.mux, args.StreamId) client, err := newClientWithMux(h.mux, args.StreamId)
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
defer client.Close() defer client.Close()
if err := h.hook.Run(args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { if err := h.hook.Run(ctx, args.Name, client.Ui(), client.Communicator(), args.Data); err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
*reply = nil *reply = nil
return nil return nil
} }
func (h *HookServer) Cancel(args *interface{}, reply *interface{}) error {
h.hook.Cancel()
return nil
}

View File

@ -1,6 +1,7 @@
package rpc package rpc
import ( import (
"context"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
@ -12,6 +13,7 @@ import (
func TestHookRPC(t *testing.T) { func TestHookRPC(t *testing.T) {
// Create the UI to test // Create the UI to test
h := new(packer.MockHook) h := new(packer.MockHook)
ctx, cancel := context.WithCancel(context.Background())
// Serve // Serve
client, server := testClientServer(t) client, server := testClientServer(t)
@ -22,13 +24,13 @@ func TestHookRPC(t *testing.T) {
// Test Run // Test Run
ui := &testUi{} ui := &testUi{}
hClient.Run("foo", ui, nil, 42) hClient.Run(ctx, "foo", ui, nil, 42)
if !h.RunCalled { if !h.RunCalled {
t.Fatal("should be called") t.Fatal("should be called")
} }
// Test Cancel // Test Cancel
hClient.Cancel() cancel()
if !h.CancelCalled { if !h.CancelCalled {
t.Fatal("should be called") t.Fatal("should be called")
} }
@ -52,6 +54,7 @@ func TestHook_cancelWhileRun(t *testing.T) {
return nil return nil
}, },
} }
ctx, cancel := context.WithCancel(context.Background())
// Serve // Serve
client, server := testClientServer(t) client, server := testClientServer(t)
@ -63,13 +66,13 @@ func TestHook_cancelWhileRun(t *testing.T) {
// Start the run // Start the run
finished := make(chan struct{}) finished := make(chan struct{})
go func() { go func() {
hClient.Run("foo", nil, nil, nil) hClient.Run(ctx, "foo", nil, nil, nil)
close(finished) close(finished)
}() }()
// Cancel it pretty quickly. // Cancel it pretty quickly.
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
hClient.Cancel() cancel()
finishLock.Lock() finishLock.Lock()
finishOrder = append(finishOrder, "cancel") finishOrder = append(finishOrder, "cancel")