packer/plugin: communicate over unix domain sockets if you can
This commit is contained in:
parent
8d9a6eef1b
commit
ae00414bbf
|
@ -34,7 +34,7 @@ type Client struct {
|
||||||
exited bool
|
exited bool
|
||||||
doneLogging chan struct{}
|
doneLogging chan struct{}
|
||||||
l sync.Mutex
|
l sync.Mutex
|
||||||
address string
|
address net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConfig is the configuration used to initialize a new
|
// ClientConfig is the configuration used to initialize a new
|
||||||
|
@ -206,11 +206,11 @@ func (c *Client) Kill() {
|
||||||
// This method is safe to call multiple times. Subsequent calls have no effect.
|
// This method is safe to call multiple times. Subsequent calls have no effect.
|
||||||
// Once a client has been started once, it cannot be started again, even if
|
// Once a client has been started once, it cannot be started again, even if
|
||||||
// it was killed.
|
// it was killed.
|
||||||
func (c *Client) Start() (address string, err error) {
|
func (c *Client) Start() (addr net.Addr, err error) {
|
||||||
c.l.Lock()
|
c.l.Lock()
|
||||||
defer c.l.Unlock()
|
defer c.l.Unlock()
|
||||||
|
|
||||||
if c.address != "" {
|
if c.address != nil {
|
||||||
return c.address, nil
|
return c.address, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,8 +320,8 @@ func (c *Client) Start() (address string, err error) {
|
||||||
// Trim the line and split by "|" in order to get the parts of
|
// Trim the line and split by "|" in order to get the parts of
|
||||||
// the output.
|
// the output.
|
||||||
line := strings.TrimSpace(string(lineBytes))
|
line := strings.TrimSpace(string(lineBytes))
|
||||||
parts := strings.SplitN(line, "|", 2)
|
parts := strings.SplitN(line, "|", 3)
|
||||||
if len(parts) < 2 {
|
if len(parts) < 3 {
|
||||||
err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
|
err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -333,10 +333,17 @@ func (c *Client) Start() (address string, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.address = parts[1]
|
switch parts[1] {
|
||||||
address = c.address
|
case "tcp":
|
||||||
|
addr, err = net.ResolveTCPAddr("tcp", parts[2])
|
||||||
|
case "unix":
|
||||||
|
addr, err = net.ResolveUnixAddr("unix", parts[2])
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unknown address type: %s", parts[1])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.address = addr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,23 +368,24 @@ func (c *Client) logStderr(r io.Reader) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) packrpcClient() (*packrpc.Client, error) {
|
func (c *Client) packrpcClient() (*packrpc.Client, error) {
|
||||||
address, err := c.Start()
|
addr, err := c.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", address)
|
conn, err := net.Dial(addr.Network(), addr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure to set keep alive so that the connection doesn't die
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||||
tcpConn := conn.(*net.TCPConn)
|
// Make sure to set keep alive so that the connection doesn't die
|
||||||
tcpConn.SetKeepAlive(true)
|
tcpConn.SetKeepAlive(true)
|
||||||
|
}
|
||||||
|
|
||||||
client, err := packrpc.NewClient(tcpConn)
|
client, err := packrpc.NewClient(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tcpConn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,12 @@ func TestClient(t *testing.T) {
|
||||||
t.Fatalf("err should be nil, got %s", err)
|
t.Fatalf("err should be nil, got %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr != ":1234" {
|
if addr.Network() != "tcp" {
|
||||||
t.Fatalf("incorrect addr %s", addr)
|
t.Fatalf("bad: %#v", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.String() != ":1234" {
|
||||||
|
t.Fatalf("bad: %#v", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that it exits properly if killed
|
// Test that it exits properly if killed
|
||||||
|
|
|
@ -51,7 +51,7 @@ func TestHelperProcess(*testing.T) {
|
||||||
cmd, args := args[0], args[1:]
|
cmd, args := args[0], args[1:]
|
||||||
switch cmd {
|
switch cmd {
|
||||||
case "bad-version":
|
case "bad-version":
|
||||||
fmt.Printf("%s1|:1234\n", APIVersion)
|
fmt.Printf("%s1|tcp|:1234\n", APIVersion)
|
||||||
<-make(chan int)
|
<-make(chan int)
|
||||||
case "builder":
|
case "builder":
|
||||||
server, err := Server()
|
server, err := Server()
|
||||||
|
@ -80,7 +80,7 @@ func TestHelperProcess(*testing.T) {
|
||||||
case "invalid-rpc-address":
|
case "invalid-rpc-address":
|
||||||
fmt.Println("lolinvalid")
|
fmt.Println("lolinvalid")
|
||||||
case "mock":
|
case "mock":
|
||||||
fmt.Printf("%s|:1234\n", APIVersion)
|
fmt.Printf("%s|tcp|:1234\n", APIVersion)
|
||||||
<-make(chan int)
|
<-make(chan int)
|
||||||
case "post-processor":
|
case "post-processor":
|
||||||
server, err := Server()
|
server, err := Server()
|
||||||
|
@ -102,11 +102,11 @@ func TestHelperProcess(*testing.T) {
|
||||||
time.Sleep(1 * time.Minute)
|
time.Sleep(1 * time.Minute)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
case "stderr":
|
case "stderr":
|
||||||
fmt.Printf("%s|:1234\n", APIVersion)
|
fmt.Printf("%s|tcp|:1234\n", APIVersion)
|
||||||
log.Println("HELLO")
|
log.Println("HELLO")
|
||||||
log.Println("WORLD")
|
log.Println("WORLD")
|
||||||
case "stdin":
|
case "stdin":
|
||||||
fmt.Printf("%s|:1234\n", APIVersion)
|
fmt.Printf("%s|tcp|:1234\n", APIVersion)
|
||||||
data := make([]byte, 5)
|
data := make([]byte, 5)
|
||||||
if _, err := os.Stdin.Read(data); err != nil {
|
if _, err := os.Stdin.Read(data); err != nil {
|
||||||
log.Printf("stdin read error: %s", err)
|
log.Printf("stdin read error: %s", err)
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/mitchellh/packer/packer"
|
"github.com/mitchellh/packer/packer"
|
||||||
packrpc "github.com/mitchellh/packer/packer/rpc"
|
packrpc "github.com/mitchellh/packer/packer/rpc"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
@ -32,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69
|
||||||
// The APIVersion is outputted along with the RPC address. The plugin
|
// The APIVersion is outputted along with the RPC address. The plugin
|
||||||
// client validates this API version and will show an error if it doesn't
|
// client validates this API version and will show an error if it doesn't
|
||||||
// know how to speak it.
|
// know how to speak it.
|
||||||
const APIVersion = "1"
|
const APIVersion = "2"
|
||||||
|
|
||||||
// Server waits for a connection to this plugin and returns a Packer
|
// Server waits for a connection to this plugin and returns a Packer
|
||||||
// RPC server that you can use to register components and serve them.
|
// RPC server that you can use to register components and serve them.
|
||||||
|
@ -62,23 +63,19 @@ func Server() (*packrpc.Server, error) {
|
||||||
log.Printf("Plugin minimum port: %d\n", minPort)
|
log.Printf("Plugin minimum port: %d\n", minPort)
|
||||||
log.Printf("Plugin maximum port: %d\n", maxPort)
|
log.Printf("Plugin maximum port: %d\n", maxPort)
|
||||||
|
|
||||||
var address string
|
listener, err := serverListener(minPort, maxPort)
|
||||||
var listener net.Listener
|
if err != nil {
|
||||||
for port := minPort; port <= maxPort; port++ {
|
return nil, err
|
||||||
address = fmt.Sprintf("127.0.0.1:%d", port)
|
|
||||||
listener, err = net.Listen("tcp", address)
|
|
||||||
if err != nil {
|
|
||||||
err = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
|
|
||||||
// Output the address to stdout
|
// Output the address to stdout
|
||||||
log.Printf("Plugin address: %s\n", address)
|
log.Printf("Plugin address: %s %s\n",
|
||||||
fmt.Printf("%s|%s\n", APIVersion, address)
|
listener.Addr().Network(), listener.Addr().String())
|
||||||
|
fmt.Printf("%s|%s|%s\n",
|
||||||
|
APIVersion,
|
||||||
|
listener.Addr().Network(),
|
||||||
|
listener.Addr().String())
|
||||||
os.Stdout.Sync()
|
os.Stdout.Sync()
|
||||||
|
|
||||||
// Accept a connection
|
// Accept a connection
|
||||||
|
@ -105,3 +102,42 @@ func Server() (*packrpc.Server, error) {
|
||||||
log.Println("Serving a plugin connection...")
|
log.Println("Serving a plugin connection...")
|
||||||
return packrpc.NewServer(conn), nil
|
return packrpc.NewServer(conn), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func serverListener(minPort, maxPort int64) (net.Listener, error) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
return serverListener_tcp(minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverListener_unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverListener_tcp(minPort, maxPort int64) (net.Listener, error) {
|
||||||
|
for port := minPort; port <= maxPort; port++ {
|
||||||
|
address := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
listener, err := net.Listen("tcp", address)
|
||||||
|
if err == nil {
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("Couldn't bind plugin TCP listener")
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverListener_unix() (net.Listener, error) {
|
||||||
|
tf, err := ioutil.TempFile("", "packer-plugin")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
path := tf.Name()
|
||||||
|
|
||||||
|
// Close the file and remove it because it has to not exist for
|
||||||
|
// the domain socket.
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.Listen("unix", path)
|
||||||
|
}
|
Loading…
Reference in New Issue