[Builder|Build|PostProcessor|Provisioner|Hook]Server: context cancel using a RPC cancel method

This commit is contained in:
Adrien Delorme 2019-03-27 12:29:09 +01:00
parent 14048b1e11
commit 82c4b76639
25 changed files with 442 additions and 356 deletions

View File

@ -50,6 +50,9 @@ func (s *StepProvision) Run(ctx context.Context, state multistep.StateBag) multi
} }
return multistep.ActionContinue return multistep.ActionContinue
case <-ctx.Done():
log.Printf("Cancelling provisioning due to context cancellation: %s", ctx.Err())
return multistep.ActionHalt
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
if _, ok := state.GetOk(multistep.StateCancelled); ok { if _, ok := state.GetOk(multistep.StateCancelled); ok {
log.Println("Cancelling provisioning due to interrupt...") log.Println("Cancelling provisioning due to interrupt...")

View File

@ -20,14 +20,11 @@ type BasicRunner struct {
// modified. // modified.
Steps []Step Steps []Step
cancel context.CancelFunc l sync.Mutex
doneCh chan struct{} state runState
state runState
l sync.Mutex
} }
func (b *BasicRunner) Run(ctx context.Context, state StateBag) { func (b *BasicRunner) Run(ctx context.Context, state StateBag) {
ctx, cancel := context.WithCancel(ctx)
b.l.Lock() b.l.Lock()
if b.state != stateIdle { if b.state != stateIdle {
@ -35,15 +32,11 @@ func (b *BasicRunner) Run(ctx context.Context, state StateBag) {
} }
doneCh := make(chan struct{}) doneCh := make(chan struct{})
b.cancel = cancel
b.doneCh = doneCh
b.state = stateRunning b.state = stateRunning
b.l.Unlock() b.l.Unlock()
defer func() { defer func() {
b.l.Lock() b.l.Lock()
b.cancel = nil
b.doneCh = nil
b.state = stateIdle b.state = stateIdle
close(doneCh) close(doneCh)
b.l.Unlock() b.l.Unlock()
@ -54,14 +47,16 @@ func (b *BasicRunner) Run(ctx context.Context, state StateBag) {
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// Flag cancel and wait for finish
state.Put(StateCancelled, true) state.Put(StateCancelled, true)
<-doneCh
case <-doneCh: case <-doneCh:
} }
}() }()
for _, step := range b.Steps { for _, step := range b.Steps {
if err := ctx.Err(); err != nil {
state.Put(StateCancelled, true)
break
}
// We also check for cancellation here since we can't be sure // We also check for cancellation here since we can't be sure
// the goroutine that is running to set it actually ran. // the goroutine that is running to set it actually ran.
if runState(atomic.LoadInt32((*int32)(&b.state))) == stateCancelling { if runState(atomic.LoadInt32((*int32)(&b.state))) == stateCancelling {

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"reflect" "reflect"
"testing" "testing"
"time"
) )
func TestBasicRunner_ImplRunner(t *testing.T) { func TestBasicRunner_ImplRunner(t *testing.T) {
@ -99,39 +98,47 @@ func TestBasicRunner_Run_Run(t *testing.T) {
} }
func TestBasicRunner_Cancel(t *testing.T) { func TestBasicRunner_Cancel(t *testing.T) {
ch := make(chan chan bool)
data := new(BasicStateBag)
stepA := &TestStepAcc{Data: "a"}
stepB := &TestStepAcc{Data: "b"}
stepInt := &TestStepSync{ch}
stepC := &TestStepAcc{Data: "c"}
r := &BasicRunner{Steps: []Step{stepA, stepB, stepInt, stepC}} topCtx, topCtxCancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(context.Background()) checkCancelled := func(data StateBag) {
cancelled := data.Get(StateCancelled).(bool)
go r.Run(ctx, data) if !cancelled {
t.Fatal("state should be cancelled")
// Wait until we reach the sync point
responseCh := <-ch
// Cancel then continue chain
cancelCh := make(chan bool)
go func() {
cancel()
cancelCh <- true
}()
for {
if _, ok := data.GetOk(StateCancelled); ok {
responseCh <- true
break
} }
time.Sleep(10 * time.Millisecond)
} }
<-cancelCh data := new(BasicStateBag)
r := &BasicRunner{}
r.Steps = []Step{
&TestStepAcc{Data: "a"},
&TestStepAcc{Data: "b"},
TestStepFn{
run: func(ctx context.Context, sb StateBag) StepAction {
return ActionContinue
},
cleanup: checkCancelled,
},
TestStepFn{
run: func(ctx context.Context, sb StateBag) StepAction {
topCtxCancel()
<-ctx.Done()
return ActionContinue
},
cleanup: checkCancelled,
},
TestStepFn{
run: func(context.Context, StateBag) StepAction {
t.Fatal("I should not be called")
return ActionContinue
},
cleanup: func(StateBag) {
t.Fatal("I should not be called")
},
},
}
r.Run(topCtx, data)
// Test run data // Test run data
expected := []string{"a", "b"} expected := []string{"a", "b"}
@ -148,10 +155,8 @@ func TestBasicRunner_Cancel(t *testing.T) {
} }
// Test that it says it is cancelled // Test that it says it is cancelled
cancelled := data.Get(StateCancelled).(bool) checkCancelled(data)
if !cancelled {
t.Errorf("not cancelled")
}
} }
func TestBasicRunner_Cancel_Special(t *testing.T) { func TestBasicRunner_Cancel_Special(t *testing.T) {

View File

@ -78,41 +78,64 @@ func TestDebugRunner_Run_Run(t *testing.T) {
t.Errorf("Was able to run an already running DebugRunner") t.Errorf("Was able to run an already running DebugRunner")
} }
type TestStepFn struct {
run func(context.Context, StateBag) StepAction
cleanup func(StateBag)
}
var _ Step = TestStepFn{}
func (fn TestStepFn) Run(ctx context.Context, sb StateBag) StepAction {
return fn.run(ctx, sb)
}
func (fn TestStepFn) Cleanup(sb StateBag) {
if fn.cleanup != nil {
fn.cleanup(sb)
}
}
func TestDebugRunner_Cancel(t *testing.T) { func TestDebugRunner_Cancel(t *testing.T) {
ch := make(chan chan bool)
data := new(BasicStateBag)
stepA := &TestStepAcc{Data: "a"}
stepB := &TestStepAcc{Data: "b"}
stepInt := &TestStepSync{ch}
stepC := &TestStepAcc{Data: "c"}
r := &DebugRunner{} topCtx, topCtxCancel := context.WithCancel(context.Background())
r.Steps = []Step{stepA, stepB, stepInt, stepC}
ctx, cancel := context.WithCancel(context.Background()) checkCancelled := func(data StateBag) {
cancelled := data.Get(StateCancelled).(bool)
go r.Run(ctx, data) if !cancelled {
t.Fatal("state should be cancelled")
// Wait until we reach the sync point
responseCh := <-ch
// Cancel then continue chain
cancelCh := make(chan bool)
go func() {
cancel()
cancelCh <- true
}()
for {
if _, ok := data.GetOk(StateCancelled); ok {
responseCh <- true
break
} }
time.Sleep(10 * time.Millisecond)
} }
<-cancelCh data := new(BasicStateBag)
r := &DebugRunner{}
r.Steps = []Step{
&TestStepAcc{Data: "a"},
&TestStepAcc{Data: "b"},
TestStepFn{
run: func(ctx context.Context, sb StateBag) StepAction {
return ActionContinue
},
cleanup: checkCancelled,
},
TestStepFn{
run: func(ctx context.Context, sb StateBag) StepAction {
topCtxCancel()
<-ctx.Done()
return ActionContinue
},
cleanup: checkCancelled,
},
TestStepFn{
run: func(context.Context, StateBag) StepAction {
t.Fatal("I should not be called")
return ActionContinue
},
cleanup: func(StateBag) {
t.Fatal("I should not be called")
},
},
}
r.Run(topCtx, data)
// Test run data // Test run data
expected := []string{"a", "b"} expected := []string{"a", "b"}
@ -129,8 +152,11 @@ func TestDebugRunner_Cancel(t *testing.T) {
} }
// Test that it says it is cancelled // Test that it says it is cancelled
cancelled := data.Get(StateCancelled).(bool) cancelled, ok := data.GetOk(StateCancelled)
if !cancelled { if !ok {
t.Fatal("could not get state cancelled")
}
if !cancelled.(bool) {
t.Errorf("not cancelled") t.Errorf("not cancelled")
} }
} }

View File

@ -58,7 +58,7 @@ func (s TestStepSync) Run(context.Context, StateBag) StepAction {
return ActionContinue return ActionContinue
} }
func (s TestStepSync) Cleanup(StateBag) {} func (s TestStepSync) Cleanup(StateBag) { close(s.Ch) }
func (s TestStepWaitForever) Run(context.Context, StateBag) StepAction { func (s TestStepWaitForever) Run(context.Context, StateBag) StepAction {
select {} select {}

View File

@ -371,12 +371,18 @@ func TestBuild_RunBeforePrepare(t *testing.T) {
func TestBuild_Cancel(t *testing.T) { func TestBuild_Cancel(t *testing.T) {
build := testBuild() build := testBuild()
ctx, cancel := context.WithCancel(context.Background()) build.Prepare()
cancel()
build.Run(ctx, nil) topCtx, topCtxCancel := context.WithCancel(context.Background())
builder := build.builder.(*MockBuilder) builder := build.builder.(*MockBuilder)
if !builder.CancelCalled {
t.Fatal("cancel should be called") builder.RunFn = func(ctx context.Context) {
topCtxCancel()
}
_, err := build.Run(topCtx, testUi())
if err == nil {
t.Fatal("build should err")
} }
} }

View File

@ -20,6 +20,7 @@ type MockBuilder struct {
RunHook Hook RunHook Hook
RunUi Ui RunUi Ui
CancelCalled bool CancelCalled bool
RunFn func(ctx context.Context)
} }
func (tb *MockBuilder) Prepare(config ...interface{}) ([]string, error) { func (tb *MockBuilder) Prepare(config ...interface{}) ([]string, error) {
@ -40,6 +41,9 @@ func (tb *MockBuilder) Run(ctx context.Context, ui Ui, h Hook) (Artifact, error)
if tb.RunNilResult { if tb.RunNilResult {
return nil, nil return nil, nil
} }
if tb.RunFn != nil {
tb.RunFn(ctx)
}
if h != nil { if h != nil {
if err := h.Run(ctx, HookProvision, ui, new(MockCommunicator), nil); err != nil { if err := h.Run(ctx, HookProvision, ui, new(MockCommunicator), nil); err != nil {

View File

@ -2,31 +2,21 @@ package packer
import ( import (
"context" "context"
"time"
) )
// MockHook is an implementation of Hook that can be used for tests. // MockHook is an implementation of Hook that can be used for tests.
type MockHook struct { type MockHook struct {
RunFunc func() error RunFunc func(context.Context) error
RunCalled bool RunCalled bool
RunComm Communicator RunComm Communicator
RunData interface{} RunData interface{}
RunName string RunName string
RunUi Ui RunUi Ui
CancelCalled bool
} }
func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error { func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicator, data interface{}) error {
go func() {
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
t.CancelCalled = true
}
}()
t.RunCalled = true t.RunCalled = true
t.RunComm = comm t.RunComm = comm
t.RunData = data t.RunData = data
@ -37,5 +27,5 @@ func (t *MockHook) Run(ctx context.Context, name string, ui Ui, comm Communicato
return nil return nil
} }
return t.RunFunc() return t.RunFunc(ctx)
} }

View File

@ -2,53 +2,9 @@ package packer
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time"
) )
// A helper Hook implementation for testing cancels.
type CancelHook struct {
sync.Mutex
cancelCh chan struct{}
doneCh chan struct{}
Cancelled bool
}
func (h *CancelHook) Run(ctx context.Context, _ string, _ Ui, _ Communicator, _ interface{}) error {
go func() {
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
h.cancel()
}
}()
h.Lock()
h.cancelCh = make(chan struct{})
h.doneCh = make(chan struct{})
h.Unlock()
defer close(h.doneCh)
select {
case <-h.cancelCh:
h.Cancelled = true
case <-time.After(1 * time.Second):
}
return nil
}
func (h *CancelHook) cancel() {
h.Lock()
close(h.cancelCh)
h.Unlock()
<-h.doneCh
}
func TestDispatchHook_Implements(t *testing.T) { func TestDispatchHook_Implements(t *testing.T) {
var _ Hook = new(DispatchHook) var _ Hook = new(DispatchHook)
} }
@ -78,20 +34,36 @@ func TestDispatchHook_Run(t *testing.T) {
} }
} }
// A helper Hook implementation for testing cancels.
// Run will wait indetinitelly until ctx is cancelled.
type CancelHook struct {
cancel func()
}
func (h *CancelHook) Run(ctx context.Context, _ string, _ Ui, _ Communicator, _ interface{}) error {
h.cancel()
<-ctx.Done()
return ctx.Err()
}
func TestDispatchHook_cancel(t *testing.T) { func TestDispatchHook_cancel(t *testing.T) {
hook := new(CancelHook)
cancelHook := new(CancelHook)
dh := &DispatchHook{ dh := &DispatchHook{
Mapping: map[string][]Hook{ Mapping: map[string][]Hook{
"foo": {hook}, "foo": {cancelHook},
}, },
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go dh.Run(ctx, "foo", nil, nil, 42) cancelHook.cancel = cancel
time.Sleep(100 * time.Millisecond)
cancel()
if !hook.Cancelled { errchan := make(chan error)
t.Fatal("hook should've cancelled") go func() {
errchan <- dh.Run(ctx, "foo", nil, nil, 42)
}()
if err := <-errchan; err == nil {
t.Fatal("hook should've errored")
} }
} }

View File

@ -101,7 +101,7 @@ func Server() (*packrpc.Server, error) {
// Serve a single connection // Serve a single connection
log.Println("Serving a plugin connection...") log.Println("Serving a plugin connection...")
return packrpc.NewServer(conn), nil return packrpc.NewServer(conn)
} }
func serverListener(minPort, maxPort int64) (net.Listener, error) { func serverListener(minPort, maxPort int64) (net.Listener, error) {

View File

@ -2,20 +2,18 @@ package packer
import ( import (
"context" "context"
"time"
) )
// MockProvisioner is an implementation of Provisioner that can be // MockProvisioner is an implementation of Provisioner that can be
// used for tests. // used for tests.
type MockProvisioner struct { type MockProvisioner struct {
ProvFunc func() error ProvFunc func(context.Context) error
PrepCalled bool PrepCalled bool
PrepConfigs []interface{} PrepConfigs []interface{}
ProvCalled bool ProvCalled bool
ProvCommunicator Communicator ProvCommunicator Communicator
ProvUi Ui ProvUi Ui
CancelCalled bool
} }
func (t *MockProvisioner) Prepare(configs ...interface{}) error { func (t *MockProvisioner) Prepare(configs ...interface{}) error {
@ -25,14 +23,6 @@ func (t *MockProvisioner) Prepare(configs ...interface{}) error {
} }
func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error { func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicator) error {
go func() {
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
t.CancelCalled = true
}
}()
t.ProvCalled = true t.ProvCalled = true
t.ProvCommunicator = comm t.ProvCommunicator = comm
t.ProvUi = ui t.ProvUi = ui
@ -41,7 +31,7 @@ func (t *MockProvisioner) Provision(ctx context.Context, ui Ui, comm Communicato
return nil return nil
} }
return t.ProvFunc() return t.ProvFunc(ctx)
} }
func (t *MockProvisioner) Communicator() Communicator { func (t *MockProvisioner) Communicator() Communicator {

View File

@ -2,7 +2,6 @@ package packer
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time" "time"
) )
@ -63,18 +62,13 @@ func TestProvisionHook_nilComm(t *testing.T) {
} }
func TestProvisionHook_cancel(t *testing.T) { func TestProvisionHook_cancel(t *testing.T) {
var lock sync.Mutex topCtx, topCtxCancel := context.WithCancel(context.Background())
order := make([]string, 0, 2)
p := &MockProvisioner{ p := &MockProvisioner{
ProvFunc: func() error { ProvFunc: func(ctx context.Context) error {
time.Sleep(100 * time.Millisecond) topCtxCancel()
<-ctx.Done()
lock.Lock() return ctx.Err()
defer lock.Unlock()
order = append(order, "prov")
return nil
}, },
} }
@ -83,27 +77,10 @@ func TestProvisionHook_cancel(t *testing.T) {
{p, nil, ""}, {p, nil, ""},
}, },
} }
ctx, cancel := context.WithCancel(context.Background())
finished := make(chan struct{}) err := hook.Run(topCtx, "foo", nil, new(MockCommunicator), nil)
go func() { if err == nil {
hook.Run(ctx, "foo", nil, new(MockCommunicator), nil) t.Fatal("should have err")
close(finished)
}()
// Cancel it while it is running
time.Sleep(10 * time.Millisecond)
cancel()
lock.Lock()
order = append(order, "cancel")
lock.Unlock()
// Wait
<-finished
// Verify order
if len(order) != 2 || order[0] != "cancel" || order[1] != "prov" {
t.Fatalf("bad: %#v", order)
} }
} }
@ -156,7 +133,7 @@ func TestPausedProvisionerProvision_waits(t *testing.T) {
} }
dataCh := make(chan struct{}) dataCh := make(chan struct{})
mock.ProvFunc = func() error { mock.ProvFunc = func(context.Context) error {
close(dataCh) close(dataCh)
return nil return nil
} }
@ -177,28 +154,22 @@ func TestPausedProvisionerProvision_waits(t *testing.T) {
} }
func TestPausedProvisionerCancel(t *testing.T) { func TestPausedProvisionerCancel(t *testing.T) {
topCtx, cancelTopCtx := context.WithCancel(context.Background())
mock := new(MockProvisioner) mock := new(MockProvisioner)
prov := &PausedProvisioner{ prov := &PausedProvisioner{
Provisioner: mock, Provisioner: mock,
} }
provCh := make(chan struct{}) mock.ProvFunc = func(ctx context.Context) error {
mock.ProvFunc = func() error { cancelTopCtx()
close(provCh) <-ctx.Done()
time.Sleep(10 * time.Millisecond) return ctx.Err()
return nil
} }
ctx, cancel := context.WithCancel(context.Background())
// Start provisioning and wait for it to start err := prov.Provision(topCtx, testUi(), new(MockCommunicator))
go func() { if err == nil {
<-provCh t.Fatal("should have err")
cancel()
}()
prov.Provision(ctx, testUi(), new(MockCommunicator))
if !mock.CancelCalled {
t.Fatal("cancel should be called")
} }
} }
@ -243,27 +214,21 @@ func TestDebuggedProvisionerProvision(t *testing.T) {
} }
func TestDebuggedProvisionerCancel(t *testing.T) { func TestDebuggedProvisionerCancel(t *testing.T) {
topCtx, topCtxCancel := context.WithCancel(context.Background())
mock := new(MockProvisioner) mock := new(MockProvisioner)
prov := &DebuggedProvisioner{ prov := &DebuggedProvisioner{
Provisioner: mock, Provisioner: mock,
} }
provCh := make(chan struct{}) mock.ProvFunc = func(ctx context.Context) error {
mock.ProvFunc = func() error { topCtxCancel()
close(provCh) <-ctx.Done()
time.Sleep(10 * time.Millisecond) return ctx.Err()
return nil
} }
ctx, cancel := context.WithCancel(context.Background())
// Start provisioning and wait for it to start err := prov.Provision(topCtx, testUi(), new(MockCommunicator))
go func() { if err == nil {
<-provCh t.Fatal("should have error")
cancel()
}()
prov.Provision(ctx, testUi(), new(MockCommunicator))
if !mock.CancelCalled {
t.Fatal("cancel should be called")
} }
} }

View File

@ -2,6 +2,7 @@ package rpc
import ( import (
"context" "context"
"log"
"net/rpc" "net/rpc"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
@ -17,6 +18,9 @@ type build struct {
// BuildServer wraps a packer.Build implementation and makes it exportable // BuildServer wraps a packer.Build implementation and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type BuildServer struct { type BuildServer struct {
context context.Context
contextCancel func()
build packer.Build build packer.Build
mux *muxBroker mux *muxBroker
} }
@ -50,6 +54,19 @@ func (b *build) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error
server.RegisterUi(ui) server.RegisterUi(ui)
go server.Serve() go server.Serve()
done := make(chan interface{})
defer close(done)
go func() {
select {
case <-ctx.Done():
log.Printf("Cancelling build after context cancellation %v", ctx.Err())
if err := b.client.Call("Build.Cancel", new(interface{}), new(interface{})); err != nil {
log.Printf("Error cancelling builder: %s", err)
}
case <-done:
}
}()
var result []uint32 var result []uint32
if err := b.client.Call("Build.Run", nextId, &result); err != nil { if err := b.client.Call("Build.Run", nextId, &result); err != nil {
return nil, err return nil, err
@ -106,14 +123,18 @@ func (b *BuildServer) Prepare(args *interface{}, resp *BuildPrepareResponse) err
return nil return nil
} }
func (b *BuildServer) Run(ctx context.Context, streamId uint32, reply *[]uint32) error { func (b *BuildServer) Run(streamId uint32, reply *[]uint32) error {
if b.context == nil {
b.context, b.contextCancel = context.WithCancel(context.Background())
}
client, err := newClientWithMux(b.mux, streamId) client, err := newClientWithMux(b.mux, streamId)
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
defer client.Close() defer client.Close()
artifacts, err := b.build.Run(ctx, client.Ui()) artifacts, err := b.build.Run(b.context, client.Ui())
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
@ -147,6 +168,8 @@ func (b *BuildServer) SetOnError(val *string, reply *interface{}) error {
} }
func (b *BuildServer) Cancel(args *interface{}, reply *interface{}) error { func (b *BuildServer) Cancel(args *interface{}, reply *interface{}) error {
panic("cancel !") if b.contextCancel != nil {
b.contextCancel()
}
return nil return nil
} }

View File

@ -15,6 +15,7 @@ type testBuild struct {
nameCalled bool nameCalled bool
prepareCalled bool prepareCalled bool
prepareWarnings []string prepareWarnings []string
runFn func(context.Context)
runCalled bool runCalled bool
runUi packer.Ui runUi packer.Ui
setDebugCalled bool setDebugCalled bool
@ -36,13 +37,13 @@ func (b *testBuild) Prepare() ([]string, error) {
} }
func (b *testBuild) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error) { func (b *testBuild) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error) {
go func() {
<-ctx.Done()
b.cancelCalled = true
}()
b.runCalled = true b.runCalled = true
b.runUi = ui b.runUi = ui
if b.runFn != nil {
b.runFn(ctx)
}
if b.errRunResult { if b.errRunResult {
return nil, errors.New("foo") return nil, errors.New("foo")
} else { } else {
@ -62,10 +63,6 @@ func (b *testBuild) SetOnError(string) {
b.setOnErrorCalled = true b.setOnErrorCalled = true
} }
func (b *testBuild) Cancel() {
b.cancelCalled = true
}
func TestBuild(t *testing.T) { func TestBuild(t *testing.T) {
b := new(testBuild) b := new(testBuild)
client, server := testClientServer(t) client, server := testClientServer(t)
@ -74,7 +71,7 @@ func TestBuild(t *testing.T) {
server.RegisterBuild(b) server.RegisterBuild(b)
bClient := client.Build() bClient := client.Build()
ctx, cancel := context.WithCancel(context.Background()) ctx := context.Background()
// Test Name // Test Name
bClient.Name() bClient.Name()
@ -131,12 +128,33 @@ func TestBuild(t *testing.T) {
if !b.setOnErrorCalled { if !b.setOnErrorCalled {
t.Fatal("should be called") t.Fatal("should be called")
} }
}
// Test Cancel func TestBuild_cancel(t *testing.T) {
cancel() topCtx, cancelTopCtx := context.WithCancel(context.Background())
if !b.cancelCalled {
t.Fatal("should be called") b := new(testBuild)
done := make(chan interface{})
b.runFn = func(ctx context.Context) {
cancelTopCtx()
<-ctx.Done()
close(done)
} }
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuild(b)
bClient := client.Build()
bClient.Prepare()
ui := new(testUi)
bClient.Run(topCtx, ui)
// if context cancellation is not propagated, this will timeout
<-done
} }
func TestBuildPrepare_Warnings(t *testing.T) { func TestBuildPrepare_Warnings(t *testing.T) {

View File

@ -2,6 +2,7 @@ package rpc
import ( import (
"context" "context"
"log"
"net/rpc" "net/rpc"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
@ -17,6 +18,9 @@ type builder struct {
// BuilderServer wraps a packer.Builder implementation and makes it exportable // BuilderServer wraps a packer.Builder implementation and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type BuilderServer struct { type BuilderServer struct {
context context.Context
contextCancel func()
builder packer.Builder builder packer.Builder
mux *muxBroker mux *muxBroker
} }
@ -51,7 +55,21 @@ func (b *builder) Run(ctx context.Context, ui packer.Ui, hook packer.Hook) (pack
server.RegisterUi(ui) server.RegisterUi(ui)
go server.Serve() go server.Serve()
done := make(chan interface{})
defer close(done)
go func() {
select {
case <-ctx.Done():
log.Printf("Cancelling builder after context cancellation %v", ctx.Err())
if err := b.client.Call("Builder.Cancel", new(interface{}), new(interface{})); err != nil {
log.Printf("Error cancelling builder: %s", err)
}
case <-done:
}
}()
var responseId uint32 var responseId uint32
if err := b.client.Call("Builder.Run", nextId, &responseId); err != nil { if err := b.client.Call("Builder.Run", nextId, &responseId); err != nil {
return nil, err return nil, err
} }
@ -77,14 +95,18 @@ func (b *BuilderServer) Prepare(args *BuilderPrepareArgs, reply *BuilderPrepareR
return nil return nil
} }
func (b *BuilderServer) Run(ctx context.Context, streamId uint32, reply *uint32) error { func (b *BuilderServer) Run(streamId uint32, reply *uint32) error {
client, err := newClientWithMux(b.mux, streamId) client, err := newClientWithMux(b.mux, streamId)
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
defer client.Close() defer client.Close()
artifact, err := b.builder.Run(ctx, client.Ui(), client.Hook()) if b.context == nil {
b.context, b.contextCancel = context.WithCancel(context.Background())
}
artifact, err := b.builder.Run(b.context, client.Ui(), client.Hook())
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
@ -92,11 +114,16 @@ func (b *BuilderServer) Run(ctx context.Context, streamId uint32, reply *uint32)
*reply = 0 *reply = 0
if artifact != nil { if artifact != nil {
streamId = b.mux.NextId() streamId = b.mux.NextId()
server := newServerWithMux(b.mux, streamId) artifactServer := newServerWithMux(b.mux, streamId)
server.RegisterArtifact(artifact) artifactServer.RegisterArtifact(artifact)
go server.Serve() go artifactServer.Serve()
*reply = streamId *reply = streamId
} }
return nil return nil
} }
func (b *BuilderServer) Cancel(args *interface{}, reply *interface{}) error {
b.contextCancel()
return nil
}

View File

@ -127,19 +127,26 @@ func TestBuilderRun_ErrResult(t *testing.T) {
} }
func TestBuilderCancel(t *testing.T) { func TestBuilderCancel(t *testing.T) {
topCtx, topCtxCancel := context.WithCancel(context.Background())
// var runCtx context.Context
b := new(packer.MockBuilder) b := new(packer.MockBuilder)
cancelled := false
b.RunFn = func(ctx context.Context) {
topCtxCancel()
<-ctx.Done()
cancelled = true
}
client, server := testClientServer(t) client, server := testClientServer(t)
defer client.Close() defer client.Close()
defer server.Close() defer server.Close()
server.RegisterBuilder(b) server.RegisterBuilder(b)
bClient := client.Builder() bClient := client.Builder()
ctx, cancel := context.WithCancel(context.Background()) bClient.Run(topCtx, new(testUi), new(packer.MockHook))
cancel()
bClient.Run(ctx, nil, nil)
if !b.CancelCalled { if !cancelled {
t.Fatal("cancel should be called") t.Fatal("context should have been cancelled")
} }
} }

View File

@ -35,7 +35,7 @@ func testConn(t *testing.T) (net.Conn, net.Conn) {
func testClientServer(t *testing.T) (*Client, *Server) { func testClientServer(t *testing.T) (*Client, *Server) {
clientConn, serverConn := testConn(t) clientConn, serverConn := testConn(t)
server := NewServer(serverConn) server, _ := NewServer(serverConn)
go server.Serve() go server.Serve()
client, err := NewClient(clientConn) client, err := NewClient(clientConn)

View File

@ -18,6 +18,9 @@ type hook struct {
// HookServer wraps a packer.Hook implementation and makes it exportable // HookServer wraps a packer.Hook implementation and makes it exportable
// as part of a Golang RPC server. // as part of a Golang RPC server.
type HookServer struct { type HookServer struct {
context context.Context
contextCancel func()
hook packer.Hook hook packer.Hook
mux *muxBroker mux *muxBroker
} }
@ -35,6 +38,19 @@ func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.C
server.RegisterUi(ui) server.RegisterUi(ui)
go server.Serve() go server.Serve()
done := make(chan interface{})
defer close(done)
go func() {
select {
case <-ctx.Done():
log.Printf("Cancelling hook after context cancellation %v", ctx.Err())
if err := h.client.Call("Hook.Cancel", new(interface{}), new(interface{})); err != nil {
log.Printf("Error cancelling builder: %s", err)
}
case <-done:
}
}()
args := HookRunArgs{ args := HookRunArgs{
Name: name, Name: name,
Data: data, Data: data,
@ -44,24 +60,27 @@ func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.C
return h.client.Call("Hook.Run", &args, new(interface{})) return h.client.Call("Hook.Run", &args, new(interface{}))
} }
func (h *hook) Cancel() { func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error {
err := h.client.Call("Hook.Cancel", new(interface{}), new(interface{}))
if err != nil {
log.Printf("Hook.Cancel error: %s", err)
}
}
func (h *HookServer) Run(ctx context.Context, args *HookRunArgs, reply *interface{}) error {
client, err := newClientWithMux(h.mux, args.StreamId) client, err := newClientWithMux(h.mux, args.StreamId)
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
defer client.Close() defer client.Close()
if err := h.hook.Run(ctx, args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { if h.context == nil {
h.context, h.contextCancel = context.WithCancel(context.Background())
}
if err := h.hook.Run(h.context, args.Name, client.Ui(), client.Communicator(), args.Data); err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
*reply = nil *reply = nil
return nil return nil
} }
func (h *HookServer) Cancel(args *interface{}, reply *interface{}) error {
if h.contextCancel != nil {
h.contextCancel()
}
return nil
}

View File

@ -2,59 +2,25 @@ package rpc
import ( import (
"context" "context"
"reflect"
"sync"
"testing" "testing"
"time"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
) )
func TestHookRPC(t *testing.T) {
// Create the UI to test
h := new(packer.MockHook)
ctx, cancel := context.WithCancel(context.Background())
// Serve
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterHook(h)
hClient := client.Hook()
// Test Run
ui := &testUi{}
hClient.Run(ctx, "foo", ui, nil, 42)
if !h.RunCalled {
t.Fatal("should be called")
}
// Test Cancel
cancel()
if !h.CancelCalled {
t.Fatal("should be called")
}
}
func TestHook_Implements(t *testing.T) { func TestHook_Implements(t *testing.T) {
var _ packer.Hook = new(hook) var _ packer.Hook = new(hook)
} }
func TestHook_cancelWhileRun(t *testing.T) { func TestHook_cancelWhileRun(t *testing.T) {
var finishLock sync.Mutex topCtx, cancelTopCtx := context.WithCancel(context.Background())
finishOrder := make([]string, 0, 2)
h := &packer.MockHook{ h := &packer.MockHook{
RunFunc: func() error { RunFunc: func(ctx context.Context) error {
time.Sleep(100 * time.Millisecond) cancelTopCtx()
<-ctx.Done()
finishLock.Lock() return ctx.Err()
finishOrder = append(finishOrder, "run")
finishLock.Unlock()
return nil
}, },
} }
ctx, cancel := context.WithCancel(context.Background())
// Serve // Serve
client, server := testClientServer(t) client, server := testClientServer(t)
@ -64,26 +30,9 @@ func TestHook_cancelWhileRun(t *testing.T) {
hClient := client.Hook() hClient := client.Hook()
// Start the run // Start the run
finished := make(chan struct{}) err := hClient.Run(topCtx, "foo", nil, nil, nil)
go func() {
hClient.Run(ctx, "foo", nil, nil, nil)
close(finished)
}()
// Cancel it pretty quickly. if err == nil {
time.Sleep(10 * time.Millisecond) t.Fatal("should have errored")
cancel()
finishLock.Lock()
finishOrder = append(finishOrder, "cancel")
finishLock.Unlock()
// Verify things are good
<-finished
// Check the results
expected := []string{"cancel", "run"}
if !reflect.DeepEqual(finishOrder, expected) {
t.Fatalf("bad: %#v", finishOrder)
} }
} }

View File

@ -2,6 +2,7 @@ package rpc
import ( import (
"context" "context"
"log"
"net/rpc" "net/rpc"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
@ -17,6 +18,9 @@ type postProcessor struct {
// PostProcessorServer wraps a packer.PostProcessor implementation and makes it // PostProcessorServer wraps a packer.PostProcessor implementation and makes it
// exportable as part of a Golang RPC server. // exportable as part of a Golang RPC server.
type PostProcessorServer struct { type PostProcessorServer struct {
context context.Context
contextCancel func()
mux *muxBroker mux *muxBroker
p packer.PostProcessor p packer.PostProcessor
} }
@ -47,6 +51,20 @@ func (p *postProcessor) PostProcess(ctx context.Context, ui packer.Ui, a packer.
server.RegisterUi(ui) server.RegisterUi(ui)
go server.Serve() go server.Serve()
done := make(chan interface{})
defer close(done)
go func() {
select {
case <-ctx.Done():
log.Printf("Cancelling post-processor after context cancellation %v", ctx.Err())
if err := p.client.Call("PostProcessor.Cancel", new(interface{}), new(interface{})); err != nil {
log.Printf("Error cancelling post-processor: %s", err)
}
case <-done:
}
}()
var response PostProcessorProcessResponse var response PostProcessorProcessResponse
if err := p.client.Call("PostProcessor.PostProcess", nextId, &response); err != nil { if err := p.client.Call("PostProcessor.PostProcess", nextId, &response); err != nil {
return nil, false, err return nil, false, err
@ -73,15 +91,19 @@ func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply
return err return err
} }
func (p *PostProcessorServer) PostProcess(ctx context.Context, streamId uint32, reply *PostProcessorProcessResponse) error { func (p *PostProcessorServer) PostProcess(streamId uint32, reply *PostProcessorProcessResponse) error {
client, err := newClientWithMux(p.mux, streamId) client, err := newClientWithMux(p.mux, streamId)
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
defer client.Close() defer client.Close()
if p.context == nil {
p.context, p.contextCancel = context.WithCancel(context.Background())
}
streamId = 0 streamId = 0
artifactResult, keep, err := p.p.PostProcess(ctx, client.Ui(), client.Artifact()) artifactResult, keep, err := p.p.PostProcess(p.context, client.Ui(), client.Artifact())
if err == nil && artifactResult != nil { if err == nil && artifactResult != nil {
streamId = p.mux.NextId() streamId = p.mux.NextId()
server := newServerWithMux(p.mux, streamId) server := newServerWithMux(p.mux, streamId)
@ -97,3 +119,10 @@ func (p *PostProcessorServer) PostProcess(ctx context.Context, streamId uint32,
return nil return nil
} }
func (b *PostProcessorServer) Cancel(args *interface{}, reply *interface{}) error {
if b.contextCancel != nil {
b.contextCancel()
}
return nil
}

View File

@ -17,6 +17,8 @@ type TestPostProcessor struct {
ppArtifact packer.Artifact ppArtifact packer.Artifact
ppArtifactId string ppArtifactId string
ppUi packer.Ui ppUi packer.Ui
postProcessFn func(context.Context) error
} }
func (pp *TestPostProcessor) Configure(v ...interface{}) error { func (pp *TestPostProcessor) Configure(v ...interface{}) error {
@ -30,6 +32,9 @@ func (pp *TestPostProcessor) PostProcess(ctx context.Context, ui packer.Ui, a pa
pp.ppArtifact = a pp.ppArtifact = a
pp.ppArtifactId = a.Id() pp.ppArtifactId = a.Id()
pp.ppUi = ui pp.ppUi = ui
if pp.postProcessFn != nil {
return testPostProcessorArtifact, false, pp.postProcessFn(ctx)
}
return testPostProcessorArtifact, false, nil return testPostProcessorArtifact, false, nil
} }
@ -84,6 +89,41 @@ func TestPostProcessorRPC(t *testing.T) {
} }
} }
func TestPostProcessorRPC_cancel(t *testing.T) {
topCtx, cancelTopCtx := context.WithCancel(context.Background())
p := new(TestPostProcessor)
p.postProcessFn = func(ctx context.Context) error {
cancelTopCtx()
<-ctx.Done()
return ctx.Err()
}
// Start the server
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
if err := server.RegisterPostProcessor(p); err != nil {
panic(err)
}
ppClient := client.PostProcessor()
// Test Configure
config := 42
err := ppClient.Configure(config)
// Test PostProcess
a := &packer.MockArtifact{
IdValue: "ppTestId",
}
ui := new(testUi)
_, _, err = ppClient.PostProcess(topCtx, ui, a)
if err == nil {
t.Fatalf("should err")
}
}
func TestPostProcessor_Implements(t *testing.T) { func TestPostProcessor_Implements(t *testing.T) {
var raw interface{} var raw interface{}
raw = new(postProcessor) raw = new(postProcessor)

View File

@ -2,6 +2,7 @@ package rpc
import ( import (
"context" "context"
"log"
"net/rpc" "net/rpc"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
@ -17,6 +18,9 @@ type provisioner struct {
// ProvisionerServer wraps a packer.Provisioner implementation and makes it // ProvisionerServer wraps a packer.Provisioner implementation and makes it
// exportable as part of a Golang RPC server. // exportable as part of a Golang RPC server.
type ProvisionerServer struct { type ProvisionerServer struct {
context context.Context
contextCancel func()
p packer.Provisioner p packer.Provisioner
mux *muxBroker mux *muxBroker
} }
@ -41,23 +45,46 @@ func (p *provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.C
server.RegisterUi(ui) server.RegisterUi(ui)
go server.Serve() go server.Serve()
done := make(chan interface{})
defer close(done)
go func() {
select {
case <-ctx.Done():
log.Printf("Cancelling provisioner after context cancellation %v", ctx.Err())
if err := p.client.Call("Provisioner.Cancel", new(interface{}), new(interface{})); err != nil {
log.Printf("Error cancelling provisioner: %s", err)
}
case <-done:
}
}()
return p.client.Call("Provisioner.Provision", nextId, new(interface{})) return p.client.Call("Provisioner.Provision", nextId, new(interface{}))
} }
func (p *ProvisionerServer) Prepare(_ context.Context, args *ProvisionerPrepareArgs, reply *interface{}) error { func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *interface{}) error {
return p.p.Prepare(args.Configs...) return p.p.Prepare(args.Configs...)
} }
func (p *ProvisionerServer) Provision(ctx context.Context, streamId uint32, reply *interface{}) error { func (p *ProvisionerServer) Provision(streamId uint32, reply *interface{}) error {
client, err := newClientWithMux(p.mux, streamId) client, err := newClientWithMux(p.mux, streamId)
if err != nil { if err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
defer client.Close() defer client.Close()
if err := p.p.Provision(ctx, client.Ui(), client.Communicator()); err != nil { if p.context == nil {
p.context, p.contextCancel = context.WithCancel(context.Background())
}
if err := p.p.Provision(p.context, client.Ui(), client.Communicator()); err != nil {
return NewBasicError(err) return NewBasicError(err)
} }
return nil return nil
} }
func (p *ProvisionerServer) Cancel(args *interface{}, reply *interface{}) error {
p.contextCancel()
return nil
}

View File

@ -9,8 +9,15 @@ import (
) )
func TestProvisionerRPC(t *testing.T) { func TestProvisionerRPC(t *testing.T) {
topCtx, topCtxCancel := context.WithCancel(context.Background())
// Create the interface to test // Create the interface to test
p := new(packer.MockProvisioner) p := new(packer.MockProvisioner)
p.ProvFunc = func(ctx context.Context) error {
topCtxCancel()
<-ctx.Done()
return ctx.Err()
}
// Start the server // Start the server
client, server := testClientServer(t) client, server := testClientServer(t)
@ -18,7 +25,6 @@ func TestProvisionerRPC(t *testing.T) {
defer server.Close() defer server.Close()
server.RegisterProvisioner(p) server.RegisterProvisioner(p)
pClient := client.Provisioner() pClient := client.Provisioner()
ctx, cancel := context.WithCancel(context.Background())
// Test Prepare // Test Prepare
config := 42 config := 42
pClient.Prepare(config) pClient.Prepare(config)
@ -33,18 +39,13 @@ func TestProvisionerRPC(t *testing.T) {
// Test Provision // Test Provision
ui := &testUi{} ui := &testUi{}
comm := &packer.MockCommunicator{} comm := &packer.MockCommunicator{}
if err := pClient.Provision(ctx, ui, comm); err != nil { if err := pClient.Provision(topCtx, ui, comm); err == nil {
t.Fatalf("err: %v", err) t.Fatalf("Provison should have err")
} }
if !p.ProvCalled { if !p.ProvCalled {
t.Fatal("should be called") t.Fatal("should be called")
} }
// Test Cancel
cancel()
if !p.CancelCalled {
t.Fatal("cancel should be called")
}
} }
func TestProvisioner_Implements(t *testing.T) { func TestProvisioner_Implements(t *testing.T) {

View File

@ -32,12 +32,15 @@ type Server struct {
} }
// NewServer returns a new Packer RPC server. // NewServer returns a new Packer RPC server.
func NewServer(conn io.ReadWriteCloser) *Server { func NewServer(conn io.ReadWriteCloser) (*Server, error) {
mux, _ := newMuxBrokerServer(conn) mux, err := newMuxBrokerServer(conn)
if err != nil {
return nil, err
}
result := newServerWithMux(mux, 0) result := newServerWithMux(mux, 0)
result.closeMux = true result.closeMux = true
go mux.Run() go mux.Run()
return result return result, nil
} }
func newServerWithMux(mux *muxBroker, streamId uint32) *Server { func newServerWithMux(mux *muxBroker, streamId uint32) *Server {

View File

@ -342,40 +342,27 @@ func TestProvision_Cancel(t *testing.T) {
ui := testUi() ui := testUi()
p := new(Provisioner) p := new(Provisioner)
var err error
comm := new(packer.MockCommunicator) comm := new(packer.MockCommunicator)
p.Prepare(config) p.Prepare(config)
waitStart := make(chan bool) done := make(chan error)
waitDone := make(chan bool)
topCtx, cancelTopCtx := context.WithCancel(context.Background())
// Block until cancel comes through // Block until cancel comes through
waitForCommunicator = func(ctx context.Context, p *Provisioner) error { waitForCommunicator = func(ctx context.Context, p *Provisioner) error {
waitStart <- true cancelTopCtx()
panic("this test is incorrect") <-ctx.Done()
for { return ctx.Err()
select {
case <-p.cancel:
}
}
} }
ctx, cancel := context.WithCancel(context.Background())
// Create two go routines to provision and cancel in parallel // Create two go routines to provision and cancel in parallel
// Provision will block until cancel happens // Provision will block until cancel happens
go func() { go func() {
err = p.Provision(ctx, ui, comm) done <- p.Provision(topCtx, ui, comm)
waitDone <- true
}() }()
go func() {
<-waitStart
cancel()
}()
<-waitDone
// Expect interrupt error // Expect interrupt error
if err == nil { if err := <-done; err == nil {
t.Fatal("should have error") t.Fatal("should have error")
} }
} }