diff --git a/packer/plugin/plugin.go b/packer/plugin/plugin.go index 223de9b53..48da97445 100644 --- a/packer/plugin/plugin.go +++ b/packer/plugin/plugin.go @@ -19,8 +19,14 @@ import ( "os/signal" "runtime" "strconv" + "sync/atomic" ) +// This is a count of the number of interrupts the process has received. +// This is updated with sync/atomic whenever a SIGINT is received and can +// be checked by the plugin safely to take action. +var Interrupts int32 = 0 + const MagicCookieKey = "PACKER_PLUGIN_MAGIC_COOKIE" const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2" @@ -86,17 +92,18 @@ func serve(server *rpc.Server) (err error) { return } -// Registers a signal handler to "swallow" interrupts so that the +// Registers a signal handler to swallow and count interrupts so that the // plugin isn't killed. The main host Packer process is responsible // for killing the plugins when interrupted. -func swallowInterrupts() { +func countInterrupts() { ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt) go func() { for { <-ch - log.Println("Received interrupt signal. Ignoring.") + newCount := atomic.AddInt32(&Interrupts, 1) + log.Printf("Received interrupt signal (count: %d). Ignoring.", newCount) } }() } @@ -108,7 +115,7 @@ func ServeBuilder(builder packer.Builder) { server := rpc.NewServer() packrpc.RegisterBuilder(server, builder) - swallowInterrupts() + countInterrupts() if err := serve(server); err != nil { log.Printf("ERROR: %s", err) os.Exit(1) @@ -122,7 +129,7 @@ func ServeCommand(command packer.Command) { server := rpc.NewServer() packrpc.RegisterCommand(server, command) - swallowInterrupts() + countInterrupts() if err := serve(server); err != nil { log.Printf("ERROR: %s", err) os.Exit(1) @@ -136,7 +143,7 @@ func ServeHook(hook packer.Hook) { server := rpc.NewServer() packrpc.RegisterHook(server, hook) - swallowInterrupts() + countInterrupts() if err := serve(server); err != nil { log.Printf("ERROR: %s", err) os.Exit(1) @@ -150,7 +157,7 @@ func ServePostProcessor(p packer.PostProcessor) { server := rpc.NewServer() packrpc.RegisterPostProcessor(server, p) - swallowInterrupts() + countInterrupts() if err := serve(server); err != nil { log.Printf("ERROR: %s", err) os.Exit(1) @@ -164,7 +171,7 @@ func ServeProvisioner(p packer.Provisioner) { server := rpc.NewServer() packrpc.RegisterProvisioner(server, p) - swallowInterrupts() + countInterrupts() if err := serve(server); err != nil { log.Printf("ERROR: %s", err) os.Exit(1)