diff --git a/packer/provisioner_test.go b/packer/provisioner_test.go index ac6ab591e..d4b34e067 100644 --- a/packer/provisioner_test.go +++ b/packer/provisioner_test.go @@ -2,6 +2,7 @@ package packer import ( "context" + "fmt" "testing" "time" ) @@ -126,30 +127,29 @@ func TestPausedProvisionerProvision(t *testing.T) { } func TestPausedProvisionerProvision_waits(t *testing.T) { - mock := new(MockProvisioner) + startTime := time.Now() + waitTime := 50 * time.Millisecond + prov := &PausedProvisioner{ - PauseBefore: 50 * time.Millisecond, - Provisioner: mock, + PauseBefore: waitTime, + Provisioner: &MockProvisioner{ + ProvFunc: func(context.Context) error { + timeSinceStartTime := time.Since(startTime) + if timeSinceStartTime < waitTime { + return fmt.Errorf("Spent not enough time waiting: %s", timeSinceStartTime) + } + if timeSinceStartTime > waitTime+10*time.Millisecond { + return fmt.Errorf("Spent too much time waiting: %s", timeSinceStartTime) + } + return nil + }, + }, } - dataCh := make(chan struct{}) - mock.ProvFunc = func(context.Context) error { - close(dataCh) - return nil - } + err := prov.Provision(context.Background(), testUi(), new(MockCommunicator)) - go prov.Provision(context.Background(), testUi(), new(MockCommunicator)) - - select { - case <-time.After(10 * time.Millisecond): - case <-dataCh: - t.Fatal("should not be called") - } - - select { - case <-time.After(100 * time.Millisecond): - t.Fatal("never called") - case <-dataCh: + if err != nil { + t.Fatalf("prov failed: %v", err) } }