packer/plugin: communicate over unix domain sockets if you can

This commit is contained in:
Mitchell Hashimoto 2013-12-11 12:24:45 -08:00
parent 8d9a6eef1b
commit ae00414bbf
4 changed files with 82 additions and 34 deletions

View File

@ -34,7 +34,7 @@ type Client struct {
exited bool
doneLogging chan struct{}
l sync.Mutex
address string
address net.Addr
}
// ClientConfig is the configuration used to initialize a new
@ -206,11 +206,11 @@ func (c *Client) Kill() {
// This method is safe to call multiple times. Subsequent calls have no effect.
// Once a client has been started once, it cannot be started again, even if
// it was killed.
func (c *Client) Start() (address string, err error) {
func (c *Client) Start() (addr net.Addr, err error) {
c.l.Lock()
defer c.l.Unlock()
if c.address != "" {
if c.address != nil {
return c.address, nil
}
@ -320,8 +320,8 @@ func (c *Client) Start() (address string, err error) {
// Trim the line and split by "|" in order to get the parts of
// the output.
line := strings.TrimSpace(string(lineBytes))
parts := strings.SplitN(line, "|", 2)
if len(parts) < 2 {
parts := strings.SplitN(line, "|", 3)
if len(parts) < 3 {
err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
return
}
@ -333,10 +333,17 @@ func (c *Client) Start() (address string, err error) {
return
}
c.address = parts[1]
address = c.address
switch parts[1] {
case "tcp":
addr, err = net.ResolveTCPAddr("tcp", parts[2])
case "unix":
addr, err = net.ResolveUnixAddr("unix", parts[2])
default:
err = fmt.Errorf("Unknown address type: %s", parts[1])
}
}
c.address = addr
return
}
@ -361,23 +368,24 @@ func (c *Client) logStderr(r io.Reader) {
}
func (c *Client) packrpcClient() (*packrpc.Client, error) {
address, err := c.Start()
addr, err := c.Start()
if err != nil {
return nil, err
}
conn, err := net.Dial("tcp", address)
conn, err := net.Dial(addr.Network(), addr.String())
if err != nil {
return nil, err
}
// Make sure to set keep alive so that the connection doesn't die
tcpConn := conn.(*net.TCPConn)
tcpConn.SetKeepAlive(true)
if tcpConn, ok := conn.(*net.TCPConn); ok {
// Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true)
}
client, err := packrpc.NewClient(tcpConn)
client, err := packrpc.NewClient(conn)
if err != nil {
tcpConn.Close()
conn.Close()
return nil, err
}

View File

@ -20,8 +20,12 @@ func TestClient(t *testing.T) {
t.Fatalf("err should be nil, got %s", err)
}
if addr != ":1234" {
t.Fatalf("incorrect addr %s", addr)
if addr.Network() != "tcp" {
t.Fatalf("bad: %#v", addr)
}
if addr.String() != ":1234" {
t.Fatalf("bad: %#v", addr)
}
// Test that it exits properly if killed

View File

@ -51,7 +51,7 @@ func TestHelperProcess(*testing.T) {
cmd, args := args[0], args[1:]
switch cmd {
case "bad-version":
fmt.Printf("%s1|:1234\n", APIVersion)
fmt.Printf("%s1|tcp|:1234\n", APIVersion)
<-make(chan int)
case "builder":
server, err := Server()
@ -80,7 +80,7 @@ func TestHelperProcess(*testing.T) {
case "invalid-rpc-address":
fmt.Println("lolinvalid")
case "mock":
fmt.Printf("%s|:1234\n", APIVersion)
fmt.Printf("%s|tcp|:1234\n", APIVersion)
<-make(chan int)
case "post-processor":
server, err := Server()
@ -102,11 +102,11 @@ func TestHelperProcess(*testing.T) {
time.Sleep(1 * time.Minute)
os.Exit(1)
case "stderr":
fmt.Printf("%s|:1234\n", APIVersion)
fmt.Printf("%s|tcp|:1234\n", APIVersion)
log.Println("HELLO")
log.Println("WORLD")
case "stdin":
fmt.Printf("%s|:1234\n", APIVersion)
fmt.Printf("%s|tcp|:1234\n", APIVersion)
data := make([]byte, 5)
if _, err := os.Stdin.Read(data); err != nil {
log.Printf("stdin read error: %s", err)

View File

@ -12,6 +12,7 @@ import (
"fmt"
"github.com/mitchellh/packer/packer"
packrpc "github.com/mitchellh/packer/packer/rpc"
"io/ioutil"
"log"
"net"
"os"
@ -32,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69
// The APIVersion is outputted along with the RPC address. The plugin
// client validates this API version and will show an error if it doesn't
// know how to speak it.
const APIVersion = "1"
const APIVersion = "2"
// Server waits for a connection to this plugin and returns a Packer
// RPC server that you can use to register components and serve them.
@ -62,23 +63,19 @@ func Server() (*packrpc.Server, error) {
log.Printf("Plugin minimum port: %d\n", minPort)
log.Printf("Plugin maximum port: %d\n", maxPort)
var address string
var listener net.Listener
for port := minPort; port <= maxPort; port++ {
address = fmt.Sprintf("127.0.0.1:%d", port)
listener, err = net.Listen("tcp", address)
if err != nil {
err = nil
continue
}
break
listener, err := serverListener(minPort, maxPort)
if err != nil {
return nil, err
}
defer listener.Close()
// Output the address to stdout
log.Printf("Plugin address: %s\n", address)
fmt.Printf("%s|%s\n", APIVersion, address)
log.Printf("Plugin address: %s %s\n",
listener.Addr().Network(), listener.Addr().String())
fmt.Printf("%s|%s|%s\n",
APIVersion,
listener.Addr().Network(),
listener.Addr().String())
os.Stdout.Sync()
// Accept a connection
@ -105,3 +102,42 @@ func Server() (*packrpc.Server, error) {
log.Println("Serving a plugin connection...")
return packrpc.NewServer(conn), nil
}
func serverListener(minPort, maxPort int64) (net.Listener, error) {
if runtime.GOOS == "windows" {
return serverListener_tcp(minPort, maxPort)
}
return serverListener_unix()
}
func serverListener_tcp(minPort, maxPort int64) (net.Listener, error) {
for port := minPort; port <= maxPort; port++ {
address := fmt.Sprintf("127.0.0.1:%d", port)
listener, err := net.Listen("tcp", address)
if err == nil {
return listener, nil
}
}
return nil, errors.New("Couldn't bind plugin TCP listener")
}
func serverListener_unix() (net.Listener, error) {
tf, err := ioutil.TempFile("", "packer-plugin")
if err != nil {
return nil, err
}
path := tf.Name()
// Close the file and remove it because it has to not exist for
// the domain socket.
if err := tf.Close(); err != nil {
return nil, err
}
if err := os.Remove(path); err != nil {
return nil, err
}
return net.Listen("unix", path)
}