diff --git a/packer/plugin/client.go b/packer/plugin/client.go index 792f9d99d..3fefca15f 100644 --- a/packer/plugin/client.go +++ b/packer/plugin/client.go @@ -10,7 +10,6 @@ import ( "io/ioutil" "log" "net" - "net/rpc" "os" "os/exec" "strings" @@ -130,56 +129,56 @@ func (c *Client) Exited() bool { // Returns a builder implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Builder() (packer.Builder, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdBuilder{packrpc.Builder(client), c}, nil + return &cmdBuilder{client.Builder(), c}, nil } // Returns a command implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Command() (packer.Command, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdCommand{packrpc.Command(client), c}, nil + return &cmdCommand{client.Command(), c}, nil } // Returns a hook implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Hook() (packer.Hook, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdHook{packrpc.Hook(client), c}, nil + return &cmdHook{client.Hook(), c}, nil } // Returns a post-processor implementation that is communicating over // this client. If the client hasn't been started, this will start it. func (c *Client) PostProcessor() (packer.PostProcessor, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdPostProcessor{packrpc.PostProcessor(client), c}, nil + return &cmdPostProcessor{client.PostProcessor(), c}, nil } // Returns a provisioner implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Provisioner() (packer.Provisioner, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdProvisioner{packrpc.Provisioner(client), c}, nil + return &cmdProvisioner{client.Provisioner(), c}, nil } // End the executing subprocess (if it is running) and perform any cleanup @@ -361,7 +360,7 @@ func (c *Client) logStderr(r io.Reader) { close(c.doneLogging) } -func (c *Client) rpcClient() (*rpc.Client, error) { +func (c *Client) packrpcClient() (*packrpc.Client, error) { address, err := c.Start() if err != nil { return nil, err @@ -376,5 +375,11 @@ func (c *Client) rpcClient() (*rpc.Client, error) { tcpConn := conn.(*net.TCPConn) tcpConn.SetKeepAlive(true) - return rpc.NewClient(tcpConn), nil + client, err := packrpc.NewClient(tcpConn) + if err != nil { + tcpConn.Close() + return nil, err + } + + return client, nil } diff --git a/packer/plugin/plugin.go b/packer/plugin/plugin.go index a91fcc3ce..91aa91274 100644 --- a/packer/plugin/plugin.go +++ b/packer/plugin/plugin.go @@ -14,7 +14,6 @@ import ( packrpc "github.com/mitchellh/packer/packer/rpc" "log" "net" - "net/rpc" "os" "os/signal" "runtime" @@ -35,13 +34,14 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69 // know how to speak it. const APIVersion = "1" -// This serves a single RPC connection on the given RPC server on -// a random port. -func serve(server *rpc.Server) (err error) { +// Server waits for a connection to this plugin and returns a Packer +// RPC server that you can use to register components and serve them. +func Server() (*packrpc.Server, error) { log.Printf("Plugin build against Packer '%s'", packer.GitCommit) if os.Getenv(MagicCookieKey) != MagicCookieValue { - return errors.New("Please do not execute plugins directly. Packer will execute these for you.") + return nil, errors.New( + "Please do not execute plugins directly. Packer will execute these for you.") } // If there is no explicit number of Go threads to use, then set it @@ -51,12 +51,12 @@ func serve(server *rpc.Server) (err error) { minPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MIN_PORT"), 10, 32) if err != nil { - return + return nil, err } maxPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MAX_PORT"), 10, 32) if err != nil { - return + return nil, err } log.Printf("Plugin minimum port: %d\n", minPort) @@ -77,7 +77,6 @@ func serve(server *rpc.Server) (err error) { break } - defer listener.Close() // Output the address to stdout @@ -90,13 +89,12 @@ func serve(server *rpc.Server) (err error) { conn, err := listener.Accept() if err != nil { log.Printf("Error accepting connection: %s\n", err.Error()) - return + return nil, err } // Serve a single connection log.Println("Serving a plugin connection...") - server.ServeConn(conn) - return + return packrpc.NewServer(conn), nil } // Registers a signal handler to swallow and count interrupts so that the @@ -115,76 +113,6 @@ func countInterrupts() { }() } -// Serves a builder from a plugin. -func ServeBuilder(builder packer.Builder) { - log.Println("Preparing to serve a builder plugin...") - - server := rpc.NewServer() - packrpc.RegisterBuilder(server, builder) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a command from a plugin. -func ServeCommand(command packer.Command) { - log.Println("Preparing to serve a command plugin...") - - server := rpc.NewServer() - packrpc.RegisterCommand(server, command) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a hook from a plugin. -func ServeHook(hook packer.Hook) { - log.Println("Preparing to serve a hook plugin...") - - server := rpc.NewServer() - packrpc.RegisterHook(server, hook) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a post-processor from a plugin. -func ServePostProcessor(p packer.PostProcessor) { - log.Println("Preparing to serve a post-processor plugin...") - - server := rpc.NewServer() - packrpc.RegisterPostProcessor(server, p) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a provisioner from a plugin. -func ServeProvisioner(p packer.Provisioner) { - log.Println("Preparing to serve a provisioner plugin...") - - server := rpc.NewServer() - packrpc.RegisterProvisioner(server, p) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - // Tests whether or not the plugin was interrupted or not. func Interrupted() bool { return atomic.LoadInt32(&Interrupts) > 0 diff --git a/packer/plugin/plugin_test.go b/packer/plugin/plugin_test.go index d80aceaf7..733190ec3 100644 --- a/packer/plugin/plugin_test.go +++ b/packer/plugin/plugin_test.go @@ -54,20 +54,50 @@ func TestHelperProcess(*testing.T) { fmt.Printf("%s1|:1234\n", APIVersion) <-make(chan int) case "builder": - ServeBuilder(new(packer.MockBuilder)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterBuilder(new(packer.MockBuilder)) + server.Serve() case "command": - ServeCommand(new(helperCommand)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterCommand(new(helperCommand)) + server.Serve() case "hook": - ServeHook(new(packer.MockHook)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterHook(new(packer.MockHook)) + server.Serve() case "invalid-rpc-address": fmt.Println("lolinvalid") case "mock": fmt.Printf("%s|:1234\n", APIVersion) <-make(chan int) case "post-processor": - ServePostProcessor(new(helperPostProcessor)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterPostProcessor(new(helperPostProcessor)) + server.Serve() case "provisioner": - ServeProvisioner(new(packer.MockProvisioner)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterProvisioner(new(packer.MockProvisioner)) + server.Serve() case "start-timeout": time.Sleep(1 * time.Minute) os.Exit(1)