From 840ddb4f20d2203445c26c90e9192bd20d6fc0dd Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sun, 14 Jun 2015 11:14:47 -0700 Subject: [PATCH] provisioner/windows-restart --- plugin/provisioner-windows-restart/main.go | 15 + provisioner/windows-restart/provisioner.go | 194 ++++++++++ .../windows-restart/provisioner_test.go | 355 ++++++++++++++++++ 3 files changed, 564 insertions(+) create mode 100644 plugin/provisioner-windows-restart/main.go create mode 100644 provisioner/windows-restart/provisioner.go create mode 100644 provisioner/windows-restart/provisioner_test.go diff --git a/plugin/provisioner-windows-restart/main.go b/plugin/provisioner-windows-restart/main.go new file mode 100644 index 000000000..0adf82216 --- /dev/null +++ b/plugin/provisioner-windows-restart/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "github.com/mitchellh/packer/packer/plugin" + "github.com/mitchellh/packer/provisioner/windows-restart" +) + +func main() { + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(restart.Provisioner)) + server.Serve() +} diff --git a/provisioner/windows-restart/provisioner.go b/provisioner/windows-restart/provisioner.go new file mode 100644 index 000000000..234980183 --- /dev/null +++ b/provisioner/windows-restart/provisioner.go @@ -0,0 +1,194 @@ +package restart + +import ( + "fmt" + "log" + "time" + + "github.com/masterzen/winrm/winrm" + "github.com/mitchellh/packer/common" + "github.com/mitchellh/packer/helper/config" + "github.com/mitchellh/packer/packer" + "github.com/mitchellh/packer/template/interpolate" +) + +var DefaultRestartCommand = "shutdown /r /c \"packer restart\" /t 5 && net stop winrm" +var DefaultRestartCheckCommand = winrm.Powershell(`echo "${env:COMPUTERNAME} restarted."`) +var retryableSleep = 5 * time.Second + +type Config struct { + common.PackerConfig `mapstructure:",squash"` + + // The command used to restart the guest machine + RestartCommand string `mapstructure:"restart_command"` + + // The command used to check if the guest machine has restarted + // The output of this command will be displayed to the user + RestartCheckCommand string `mapstructure:"restart_check_command"` + + // The timeout for waiting for the machine to restart + RestartTimeout time.Duration `mapstructure:"restart_timeout"` + + ctx interpolate.Context +} + +type Provisioner struct { + config Config + comm packer.Communicator + ui packer.Ui + cancel chan struct{} +} + +func (p *Provisioner) Prepare(raws ...interface{}) error { + err := config.Decode(&p.config, &config.DecodeOpts{ + Interpolate: true, + InterpolateFilter: &interpolate.RenderFilter{ + Exclude: []string{ + "execute_command", + }, + }, + }, raws...) + if err != nil { + return err + } + + if p.config.RestartCommand == "" { + p.config.RestartCommand = DefaultRestartCommand + } + + if p.config.RestartCheckCommand == "" { + p.config.RestartCheckCommand = DefaultRestartCheckCommand + } + + if p.config.RestartTimeout == 0 { + p.config.RestartTimeout = 5 * time.Minute + } + + return nil +} + +func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { + ui.Say("Restarting Machine") + p.comm = comm + p.ui = ui + p.cancel = make(chan struct{}) + + var cmd *packer.RemoteCmd + command := p.config.RestartCommand + err := p.retryable(func() error { + cmd = &packer.RemoteCmd{Command: command} + return cmd.StartWithUi(comm, ui) + }) + + if err != nil { + return err + } + + if cmd.ExitStatus != 0 { + return fmt.Errorf("Restart script exited with non-zero exit status: %d", cmd.ExitStatus) + } + + return waitForRestart(p) +} + +var waitForRestart = func(p *Provisioner) error { + ui := p.ui + ui.Say("Waiting for machine to restart...") + waitDone := make(chan bool, 1) + timeout := time.After(p.config.RestartTimeout) + var err error + + go func() { + log.Printf("Waiting for machine to become available...") + err = waitForCommunicator(p) + waitDone <- true + }() + + log.Printf("Waiting for machine to reboot with timeout: %s", p.config.RestartTimeout) + +WaitLoop: + for { + // Wait for either WinRM to become available, a timeout to occur, + // or an interrupt to come through. + select { + case <-waitDone: + if err != nil { + ui.Error(fmt.Sprintf("Error waiting for WinRM: %s", err)) + return err + } + + ui.Say("Machine successfully restarted, moving on") + close(p.cancel) + break WaitLoop + case <-timeout: + err := fmt.Errorf("Timeout waiting for WinRM.") + ui.Error(err.Error()) + close(p.cancel) + return err + case <-p.cancel: + close(waitDone) + return fmt.Errorf("Interrupt detected, quitting waiting for machine to restart") + break WaitLoop + } + } + + return nil + +} + +var waitForCommunicator = func(p *Provisioner) error { + cmd := &packer.RemoteCmd{Command: p.config.RestartCheckCommand} + + for { + select { + case <-p.cancel: + log.Println("Communicator wait cancelled, exiting loop") + return fmt.Errorf("Communicator wait cancelled") + case <-time.After(retryableSleep): + } + + log.Printf("Attempting to communicator to machine with: '%s'", cmd.Command) + + err := cmd.StartWithUi(p.comm, p.ui) + if err != nil { + log.Printf("Communication connection err: %s", err) + continue + } + + log.Printf("Connected to machine") + break + } + + return nil +} + +func (p *Provisioner) Cancel() { + log.Printf("Received interrupt Cancel()") + close(p.cancel) +} + +// retryable will retry the given function over and over until a +// non-error is returned. +func (p *Provisioner) retryable(f func() error) error { + startTimeout := time.After(p.config.RestartTimeout) + for { + var err error + if err = f(); err == nil { + return nil + } + + // Create an error and log it + err = fmt.Errorf("Retryable error: %s", err) + log.Printf(err.Error()) + + // Check if we timed out, otherwise we retry. It is safe to + // retry since the only error case above is if the command + // failed to START. + select { + case <-startTimeout: + return err + default: + time.Sleep(retryableSleep) + } + } +} diff --git a/provisioner/windows-restart/provisioner_test.go b/provisioner/windows-restart/provisioner_test.go new file mode 100644 index 000000000..f0f2766e3 --- /dev/null +++ b/provisioner/windows-restart/provisioner_test.go @@ -0,0 +1,355 @@ +package restart + +import ( + "bytes" + "errors" + "fmt" + "github.com/mitchellh/packer/packer" + "testing" + "time" +) + +func testConfig() map[string]interface{} { + return map[string]interface{}{} +} + +func TestProvisioner_Impl(t *testing.T) { + var raw interface{} + raw = &Provisioner{} + if _, ok := raw.(packer.Provisioner); !ok { + t.Fatalf("must be a Provisioner") + } +} + +func TestProvisionerPrepare_Defaults(t *testing.T) { + var p Provisioner + config := testConfig() + + err := p.Prepare(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + if p.config.RestartTimeout != 5*time.Minute { + t.Errorf("unexpected remote path: %s", p.config.RestartTimeout) + } + + if p.config.RestartCommand != "shutdown /r /c \"packer restart\" /t 5 && net stop winrm" { + t.Errorf("unexpected remote path: %s", p.config.RestartCommand) + } +} + +func TestProvisionerPrepare_ConfigRetryTimeout(t *testing.T) { + var p Provisioner + config := testConfig() + config["restart_timeout"] = "1m" + + err := p.Prepare(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + if p.config.RestartTimeout != 1*time.Minute { + t.Errorf("unexpected remote path: %s", p.config.RestartTimeout) + } +} + +func TestProvisionerPrepare_ConfigErrors(t *testing.T) { + var p Provisioner + config := testConfig() + config["restart_timeout"] = "m" + + err := p.Prepare(config) + if err == nil { + t.Fatal("Expected error parsing restart_timeout but did not receive one.") + } +} + +func TestProvisionerPrepare_InvalidKey(t *testing.T) { + var p Provisioner + config := testConfig() + + // Add a random key + config["i_should_not_be_valid"] = true + err := p.Prepare(config) + if err == nil { + t.Fatal("should have error") + } +} + +func testUi() *packer.BasicUi { + return &packer.BasicUi{ + Reader: new(bytes.Buffer), + Writer: new(bytes.Buffer), + ErrorWriter: new(bytes.Buffer), + } +} + +func TestProvisionerProvision_Success(t *testing.T) { + config := testConfig() + + // Defaults provided by Packer + ui := testUi() + p := new(Provisioner) + + // Defaults provided by Packer + comm := new(packer.MockCommunicator) + p.Prepare(config) + waitForCommunicatorOld := waitForCommunicator + waitForCommunicator = func(p *Provisioner) error { + return nil + } + err := p.Provision(ui, comm) + if err != nil { + t.Fatal("should not have error") + } + + expectedCommand := DefaultRestartCommand + + // Should run the command without alteration + if comm.StartCmd.Command != expectedCommand { + t.Fatalf("Expect command to be: %s, got %s", expectedCommand, comm.StartCmd.Command) + } + // Set this back! + waitForCommunicator = waitForCommunicatorOld +} + +func TestProvisionerProvision_CustomCommand(t *testing.T) { + config := testConfig() + + // Defaults provided by Packer + ui := testUi() + p := new(Provisioner) + expectedCommand := "specialrestart.exe -NOW" + config["restart_command"] = expectedCommand + + // Defaults provided by Packer + comm := new(packer.MockCommunicator) + p.Prepare(config) + waitForCommunicatorOld := waitForCommunicator + waitForCommunicator = func(p *Provisioner) error { + return nil + } + err := p.Provision(ui, comm) + if err != nil { + t.Fatal("should not have error") + } + + // Should run the command without alteration + if comm.StartCmd.Command != expectedCommand { + t.Fatalf("Expect command to be: %s, got %s", expectedCommand, comm.StartCmd.Command) + } + // Set this back! + waitForCommunicator = waitForCommunicatorOld +} + +func TestProvisionerProvision_RestartCommandFail(t *testing.T) { + config := testConfig() + ui := testUi() + p := new(Provisioner) + comm := new(packer.MockCommunicator) + comm.StartStderr = "WinRM terminated" + comm.StartExitStatus = 1 + + p.Prepare(config) + err := p.Provision(ui, comm) + if err == nil { + t.Fatal("should have error") + } +} +func TestProvisionerProvision_WaitForRestartFail(t *testing.T) { + config := testConfig() + + // Defaults provided by Packer + ui := testUi() + p := new(Provisioner) + + // Defaults provided by Packer + comm := new(packer.MockCommunicator) + p.Prepare(config) + waitForCommunicatorOld := waitForCommunicator + waitForCommunicator = func(p *Provisioner) error { + return fmt.Errorf("Machine did not restart properly") + } + err := p.Provision(ui, comm) + if err == nil { + t.Fatal("should have error") + } + + // Set this back! + waitForCommunicator = waitForCommunicatorOld +} + +func TestProvision_waitForRestartTimeout(t *testing.T) { + retryableSleep = 10 * time.Millisecond + config := testConfig() + config["restart_timeout"] = "1ms" + ui := testUi() + p := new(Provisioner) + comm := new(packer.MockCommunicator) + var err error + + p.Prepare(config) + waitForCommunicatorOld := waitForCommunicator + waitDone := make(chan bool) + + // Block until cancel comes through + waitForCommunicator = func(p *Provisioner) error { + for { + select { + case <-waitDone: + } + } + } + + go func() { + err = p.Provision(ui, comm) + waitDone <- true + }() + <-waitDone + + if err == nil { + t.Fatal("should not have error") + } + + // Set this back! + waitForCommunicator = waitForCommunicatorOld + +} + +func TestProvision_waitForCommunicator(t *testing.T) { + config := testConfig() + + // Defaults provided by Packer + ui := testUi() + p := new(Provisioner) + + // Defaults provided by Packer + comm := new(packer.MockCommunicator) + p.comm = comm + p.ui = ui + comm.StartStderr = "WinRM terminated" + comm.StartExitStatus = 1 + p.Prepare(config) + err := waitForCommunicator(p) + + if err != nil { + t.Fatal("should not have error, got: %s", err.Error()) + } + + expectedCommand := DefaultRestartCheckCommand + + // Should run the command without alteration + if comm.StartCmd.Command != expectedCommand { + t.Fatalf("Expect command to be: %s, got %s", expectedCommand, comm.StartCmd.Command) + } +} + +func TestProvision_waitForCommunicatorWithCancel(t *testing.T) { + config := testConfig() + + // Defaults provided by Packer + ui := testUi() + p := new(Provisioner) + + // Defaults provided by Packer + comm := new(packer.MockCommunicator) + p.comm = comm + p.ui = ui + retryableSleep = 10 * time.Millisecond + p.cancel = make(chan struct{}) + var err error + + comm.StartStderr = "WinRM terminated" + comm.StartExitStatus = 1 // Always fail + p.Prepare(config) + + // Run 2 goroutines; + // 1st to call waitForCommunicator (that will always fail) + // 2nd to cancel the operation + waitDone := make(chan bool) + go func() { + err = waitForCommunicator(p) + }() + + go func() { + p.Cancel() + waitDone <- true + }() + <-waitDone + + // Expect a Cancel error + if err == nil { + t.Fatalf("Should have err") + } +} + +func TestRetryable(t *testing.T) { + config := testConfig() + + count := 0 + retryMe := func() error { + t.Logf("RetryMe, attempt number %d", count) + if count == 2 { + return nil + } + count++ + return errors.New(fmt.Sprintf("Still waiting %d more times...", 2-count)) + } + retryableSleep = 50 * time.Millisecond + p := new(Provisioner) + p.config.RestartTimeout = 155 * time.Millisecond + err := p.Prepare(config) + err = p.retryable(retryMe) + if err != nil { + t.Fatalf("should not have error retrying funuction") + } + + count = 0 + p.config.RestartTimeout = 10 * time.Millisecond + err = p.Prepare(config) + err = p.retryable(retryMe) + if err == nil { + t.Fatalf("should have error retrying funuction") + } +} + +func TestProvision_Cancel(t *testing.T) { + config := testConfig() + + // Defaults provided by Packer + ui := testUi() + p := new(Provisioner) + + var err error + + comm := new(packer.MockCommunicator) + p.Prepare(config) + waitDone := make(chan bool) + + // Block until cancel comes through + waitForCommunicator = func(p *Provisioner) error { + for { + select { + case <-waitDone: + } + } + } + + // Create two go routines to provision and cancel in parallel + // Provision will block until cancel happens + go func() { + err = p.Provision(ui, comm) + waitDone <- true + }() + + go func() { + p.Cancel() + }() + <-waitDone + + // Expect interupt error + if err == nil { + t.Fatal("should have error") + } +}