packer/plugin: Count number of interrupts atomically

This commit is contained in:
Mitchell Hashimoto 2013-08-30 16:29:21 -07:00
parent 71379bc8d7
commit 893c9e02c0
1 changed files with 15 additions and 8 deletions

View File

@ -19,8 +19,14 @@ import (
"os/signal"
"runtime"
"strconv"
"sync/atomic"
)
// This is a count of the number of interrupts the process has received.
// This is updated with sync/atomic whenever a SIGINT is received and can
// be checked by the plugin safely to take action.
var Interrupts int32 = 0
const MagicCookieKey = "PACKER_PLUGIN_MAGIC_COOKIE"
const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2"
@ -86,17 +92,18 @@ func serve(server *rpc.Server) (err error) {
return
}
// Registers a signal handler to "swallow" interrupts so that the
// Registers a signal handler to swallow and count interrupts so that the
// plugin isn't killed. The main host Packer process is responsible
// for killing the plugins when interrupted.
func swallowInterrupts() {
func countInterrupts() {
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
go func() {
for {
<-ch
log.Println("Received interrupt signal. Ignoring.")
newCount := atomic.AddInt32(&Interrupts, 1)
log.Printf("Received interrupt signal (count: %d). Ignoring.", newCount)
}
}()
}
@ -108,7 +115,7 @@ func ServeBuilder(builder packer.Builder) {
server := rpc.NewServer()
packrpc.RegisterBuilder(server, builder)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
@ -122,7 +129,7 @@ func ServeCommand(command packer.Command) {
server := rpc.NewServer()
packrpc.RegisterCommand(server, command)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
@ -136,7 +143,7 @@ func ServeHook(hook packer.Hook) {
server := rpc.NewServer()
packrpc.RegisterHook(server, hook)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
@ -150,7 +157,7 @@ func ServePostProcessor(p packer.PostProcessor) {
server := rpc.NewServer()
packrpc.RegisterPostProcessor(server, p)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
@ -164,7 +171,7 @@ func ServeProvisioner(p packer.Provisioner) {
server := rpc.NewServer()
packrpc.RegisterProvisioner(server, p)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)