From 2d5ca19b373764d146352d1be6f26f26994baa81 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Wed, 18 Sep 2013 17:15:48 -0700 Subject: [PATCH] packer/rpc: set keep-alive on all RPC connections [GH-416] --- CHANGELOG.md | 2 ++ packer/rpc/build.go | 4 ++-- packer/rpc/builder.go | 7 +++---- packer/rpc/command.go | 2 +- packer/rpc/communicator.go | 12 ++++++------ packer/rpc/dial.go | 33 +++++++++++++++++++++++++++++++++ packer/rpc/environment.go | 12 ++++++------ packer/rpc/hook.go | 2 +- packer/rpc/post_processor.go | 4 ++-- packer/rpc/provisioner.go | 2 +- 10 files changed, 57 insertions(+), 23 deletions(-) create mode 100644 packer/rpc/dial.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c167e1d4..8a85c2ce1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ IMPROVEMENTS: BUG FIXES: +* core: Set TCP KeepAlives on internally created RPC connections so that + they don't die. [GH-416] * builder/amazon/all: While waiting for AMI, will detect "failed" state. * builder/amazon/all: Waiting for state will detect if the resource (AMI, instance, etc.) disappears from under it. diff --git a/packer/rpc/build.go b/packer/rpc/build.go index 6d81ce870..8f21b4ce9 100644 --- a/packer/rpc/build.go +++ b/packer/rpc/build.go @@ -52,7 +52,7 @@ func (b *build) Run(ui packer.Ui, cache packer.Cache) ([]packer.Artifact, error) artifacts := make([]packer.Artifact, len(result)) for i, addr := range result { - client, err := rpc.Dial("tcp", addr) + client, err := rpcDial(addr) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (b *BuildServer) Prepare(v map[string]string, reply *error) error { } func (b *BuildServer) Run(args *BuildRunArgs, reply *[]string) error { - client, err := rpc.Dial("tcp", args.UiRPCAddress) + client, err := rpcDial(args.UiRPCAddress) if err != nil { return err } diff --git a/packer/rpc/builder.go b/packer/rpc/builder.go index 8c56e4959..5e2c20d77 100644 --- a/packer/rpc/builder.go +++ b/packer/rpc/builder.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/mitchellh/packer/packer" "log" - "net" "net/rpc" ) @@ -95,7 +94,7 @@ func (b *builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packe return nil, nil } - client, err := rpc.Dial("tcp", response.RPCAddress) + client, err := rpcDial(response.RPCAddress) if err != nil { return nil, err } @@ -119,12 +118,12 @@ func (b *BuilderServer) Prepare(args *BuilderPrepareArgs, reply *error) error { } func (b *BuilderServer) Run(args *BuilderRunArgs, reply *interface{}) error { - client, err := rpc.Dial("tcp", args.RPCAddress) + client, err := rpcDial(args.RPCAddress) if err != nil { return err } - responseC, err := net.Dial("tcp", args.ResponseAddress) + responseC, err := tcpDial(args.ResponseAddress) if err != nil { return err } diff --git a/packer/rpc/command.go b/packer/rpc/command.go index 18cd5667e..3e2b48b2f 100644 --- a/packer/rpc/command.go +++ b/packer/rpc/command.go @@ -66,7 +66,7 @@ func (c *CommandServer) Help(args *interface{}, reply *string) error { } func (c *CommandServer) Run(args *CommandRunArgs, reply *int) error { - client, err := rpc.Dial("tcp", args.RPCAddress) + client, err := rpcDial(args.RPCAddress) if err != nil { return err } diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index 21b507f0b..77e321153 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -177,7 +177,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface toClose := make([]net.Conn, 0) if args.StdinAddress != "" { - stdinC, err := net.Dial("tcp", args.StdinAddress) + stdinC, err := tcpDial(args.StdinAddress) if err != nil { return err } @@ -187,7 +187,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface } if args.StdoutAddress != "" { - stdoutC, err := net.Dial("tcp", args.StdoutAddress) + stdoutC, err := tcpDial(args.StdoutAddress) if err != nil { return err } @@ -197,7 +197,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface } if args.StderrAddress != "" { - stderrC, err := net.Dial("tcp", args.StderrAddress) + stderrC, err := tcpDial(args.StderrAddress) if err != nil { return err } @@ -208,7 +208,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface // Connect to the response address so we can write our result to it // when ready. - responseC, err := net.Dial("tcp", args.ResponseAddress) + responseC, err := tcpDial(args.ResponseAddress) if err != nil { return err } @@ -234,7 +234,7 @@ func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface } func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) { - readerC, err := net.Dial("tcp", args.ReaderAddress) + readerC, err := tcpDial(args.ReaderAddress) if err != nil { return } @@ -250,7 +250,7 @@ func (c *CommunicatorServer) UploadDir(args *CommunicatorUploadDirArgs, reply *e } func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) { - writerC, err := net.Dial("tcp", args.WriterAddress) + writerC, err := tcpDial(args.WriterAddress) if err != nil { return } diff --git a/packer/rpc/dial.go b/packer/rpc/dial.go new file mode 100644 index 000000000..10e2cad14 --- /dev/null +++ b/packer/rpc/dial.go @@ -0,0 +1,33 @@ +package rpc + +import ( + "net" + "net/rpc" +) + +// rpcDial makes a TCP connection to a remote RPC server and returns +// the client. This will set the connection up properly so that keep-alives +// are set and so on and should be used to make all RPC connections within +// this package. +func rpcDial(address string) (*rpc.Client, error) { + tcpConn, err := tcpDial(address) + if err != nil { + return nil, err + } + + // Create an RPC client around our connection + return rpc.NewClient(tcpConn), nil +} + +// tcpDial connects via TCP to the designated address. +func tcpDial(address string) (*net.TCPConn, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + return nil, err + } + + // Set a keep-alive so that the connection stays alive even when idle + tcpConn := conn.(*net.TCPConn) + tcpConn.SetKeepAlive(true) + return tcpConn, nil +} diff --git a/packer/rpc/environment.go b/packer/rpc/environment.go index 8ebf709c0..36db72c56 100644 --- a/packer/rpc/environment.go +++ b/packer/rpc/environment.go @@ -28,7 +28,7 @@ func (e *Environment) Builder(name string) (b packer.Builder, err error) { return } - client, err := rpc.Dial("tcp", reply) + client, err := rpcDial(reply) if err != nil { return } @@ -43,7 +43,7 @@ func (e *Environment) Cache() packer.Cache { panic(err) } - client, err := rpc.Dial("tcp", reply) + client, err := rpcDial(reply) if err != nil { panic(err) } @@ -64,7 +64,7 @@ func (e *Environment) Hook(name string) (h packer.Hook, err error) { return } - client, err := rpc.Dial("tcp", reply) + client, err := rpcDial(reply) if err != nil { return } @@ -80,7 +80,7 @@ func (e *Environment) PostProcessor(name string) (p packer.PostProcessor, err er return } - client, err := rpc.Dial("tcp", reply) + client, err := rpcDial(reply) if err != nil { return } @@ -96,7 +96,7 @@ func (e *Environment) Provisioner(name string) (p packer.Provisioner, err error) return } - client, err := rpc.Dial("tcp", reply) + client, err := rpcDial(reply) if err != nil { return } @@ -109,7 +109,7 @@ func (e *Environment) Ui() packer.Ui { var reply string e.client.Call("Environment.Ui", new(interface{}), &reply) - client, err := rpc.Dial("tcp", reply) + client, err := rpcDial(reply) if err != nil { panic(err) } diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 687d991a7..223b96df2 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -46,7 +46,7 @@ func (h *hook) Cancel() { } func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { - client, err := rpc.Dial("tcp", args.RPCAddress) + client, err := rpcDial(args.RPCAddress) if err != nil { return err } diff --git a/packer/rpc/post_processor.go b/packer/rpc/post_processor.go index fb43cb7d9..0a5eaefd3 100644 --- a/packer/rpc/post_processor.go +++ b/packer/rpc/post_processor.go @@ -57,7 +57,7 @@ func (p *postProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Art return nil, false, nil } - client, err := rpc.Dial("tcp", response.RPCAddress) + client, err := rpcDial(response.RPCAddress) if err != nil { return nil, false, err } @@ -75,7 +75,7 @@ func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply } func (p *PostProcessorServer) PostProcess(address string, reply *PostProcessorProcessResponse) error { - client, err := rpc.Dial("tcp", address) + client, err := rpcDial(address) if err != nil { return err } diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index 7d3ed1617..4cd329d6b 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -65,7 +65,7 @@ func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *error) } func (p *ProvisionerServer) Provision(args *ProvisionerProvisionArgs, reply *interface{}) error { - client, err := rpc.Dial("tcp", args.RPCAddress) + client, err := rpcDial(args.RPCAddress) if err != nil { return err }