packer/plugin: communicate over unix domain sockets if you can
This commit is contained in:
parent
8d9a6eef1b
commit
ae00414bbf
|
@ -34,7 +34,7 @@ type Client struct {
|
|||
exited bool
|
||||
doneLogging chan struct{}
|
||||
l sync.Mutex
|
||||
address string
|
||||
address net.Addr
|
||||
}
|
||||
|
||||
// ClientConfig is the configuration used to initialize a new
|
||||
|
@ -206,11 +206,11 @@ func (c *Client) Kill() {
|
|||
// This method is safe to call multiple times. Subsequent calls have no effect.
|
||||
// Once a client has been started once, it cannot be started again, even if
|
||||
// it was killed.
|
||||
func (c *Client) Start() (address string, err error) {
|
||||
func (c *Client) Start() (addr net.Addr, err error) {
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
|
||||
if c.address != "" {
|
||||
if c.address != nil {
|
||||
return c.address, nil
|
||||
}
|
||||
|
||||
|
@ -320,8 +320,8 @@ func (c *Client) Start() (address string, err error) {
|
|||
// Trim the line and split by "|" in order to get the parts of
|
||||
// the output.
|
||||
line := strings.TrimSpace(string(lineBytes))
|
||||
parts := strings.SplitN(line, "|", 2)
|
||||
if len(parts) < 2 {
|
||||
parts := strings.SplitN(line, "|", 3)
|
||||
if len(parts) < 3 {
|
||||
err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
|
||||
return
|
||||
}
|
||||
|
@ -333,10 +333,17 @@ func (c *Client) Start() (address string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
c.address = parts[1]
|
||||
address = c.address
|
||||
switch parts[1] {
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr("tcp", parts[2])
|
||||
case "unix":
|
||||
addr, err = net.ResolveUnixAddr("unix", parts[2])
|
||||
default:
|
||||
err = fmt.Errorf("Unknown address type: %s", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
c.address = addr
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -361,23 +368,24 @@ func (c *Client) logStderr(r io.Reader) {
|
|||
}
|
||||
|
||||
func (c *Client) packrpcClient() (*packrpc.Client, error) {
|
||||
address, err := c.Start()
|
||||
addr, err := c.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", address)
|
||||
conn, err := net.Dial(addr.Network(), addr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
// Make sure to set keep alive so that the connection doesn't die
|
||||
tcpConn := conn.(*net.TCPConn)
|
||||
tcpConn.SetKeepAlive(true)
|
||||
}
|
||||
|
||||
client, err := packrpc.NewClient(tcpConn)
|
||||
client, err := packrpc.NewClient(conn)
|
||||
if err != nil {
|
||||
tcpConn.Close()
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,12 @@ func TestClient(t *testing.T) {
|
|||
t.Fatalf("err should be nil, got %s", err)
|
||||
}
|
||||
|
||||
if addr != ":1234" {
|
||||
t.Fatalf("incorrect addr %s", addr)
|
||||
if addr.Network() != "tcp" {
|
||||
t.Fatalf("bad: %#v", addr)
|
||||
}
|
||||
|
||||
if addr.String() != ":1234" {
|
||||
t.Fatalf("bad: %#v", addr)
|
||||
}
|
||||
|
||||
// Test that it exits properly if killed
|
||||
|
|
|
@ -51,7 +51,7 @@ func TestHelperProcess(*testing.T) {
|
|||
cmd, args := args[0], args[1:]
|
||||
switch cmd {
|
||||
case "bad-version":
|
||||
fmt.Printf("%s1|:1234\n", APIVersion)
|
||||
fmt.Printf("%s1|tcp|:1234\n", APIVersion)
|
||||
<-make(chan int)
|
||||
case "builder":
|
||||
server, err := Server()
|
||||
|
@ -80,7 +80,7 @@ func TestHelperProcess(*testing.T) {
|
|||
case "invalid-rpc-address":
|
||||
fmt.Println("lolinvalid")
|
||||
case "mock":
|
||||
fmt.Printf("%s|:1234\n", APIVersion)
|
||||
fmt.Printf("%s|tcp|:1234\n", APIVersion)
|
||||
<-make(chan int)
|
||||
case "post-processor":
|
||||
server, err := Server()
|
||||
|
@ -102,11 +102,11 @@ func TestHelperProcess(*testing.T) {
|
|||
time.Sleep(1 * time.Minute)
|
||||
os.Exit(1)
|
||||
case "stderr":
|
||||
fmt.Printf("%s|:1234\n", APIVersion)
|
||||
fmt.Printf("%s|tcp|:1234\n", APIVersion)
|
||||
log.Println("HELLO")
|
||||
log.Println("WORLD")
|
||||
case "stdin":
|
||||
fmt.Printf("%s|:1234\n", APIVersion)
|
||||
fmt.Printf("%s|tcp|:1234\n", APIVersion)
|
||||
data := make([]byte, 5)
|
||||
if _, err := os.Stdin.Read(data); err != nil {
|
||||
log.Printf("stdin read error: %s", err)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"fmt"
|
||||
"github.com/mitchellh/packer/packer"
|
||||
packrpc "github.com/mitchellh/packer/packer/rpc"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
@ -32,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69
|
|||
// The APIVersion is outputted along with the RPC address. The plugin
|
||||
// client validates this API version and will show an error if it doesn't
|
||||
// know how to speak it.
|
||||
const APIVersion = "1"
|
||||
const APIVersion = "2"
|
||||
|
||||
// Server waits for a connection to this plugin and returns a Packer
|
||||
// RPC server that you can use to register components and serve them.
|
||||
|
@ -62,23 +63,19 @@ func Server() (*packrpc.Server, error) {
|
|||
log.Printf("Plugin minimum port: %d\n", minPort)
|
||||
log.Printf("Plugin maximum port: %d\n", maxPort)
|
||||
|
||||
var address string
|
||||
var listener net.Listener
|
||||
for port := minPort; port <= maxPort; port++ {
|
||||
address = fmt.Sprintf("127.0.0.1:%d", port)
|
||||
listener, err = net.Listen("tcp", address)
|
||||
listener, err := serverListener(minPort, maxPort)
|
||||
if err != nil {
|
||||
err = nil
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
return nil, err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Output the address to stdout
|
||||
log.Printf("Plugin address: %s\n", address)
|
||||
fmt.Printf("%s|%s\n", APIVersion, address)
|
||||
log.Printf("Plugin address: %s %s\n",
|
||||
listener.Addr().Network(), listener.Addr().String())
|
||||
fmt.Printf("%s|%s|%s\n",
|
||||
APIVersion,
|
||||
listener.Addr().Network(),
|
||||
listener.Addr().String())
|
||||
os.Stdout.Sync()
|
||||
|
||||
// Accept a connection
|
||||
|
@ -105,3 +102,42 @@ func Server() (*packrpc.Server, error) {
|
|||
log.Println("Serving a plugin connection...")
|
||||
return packrpc.NewServer(conn), nil
|
||||
}
|
||||
|
||||
func serverListener(minPort, maxPort int64) (net.Listener, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return serverListener_tcp(minPort, maxPort)
|
||||
}
|
||||
|
||||
return serverListener_unix()
|
||||
}
|
||||
|
||||
func serverListener_tcp(minPort, maxPort int64) (net.Listener, error) {
|
||||
for port := minPort; port <= maxPort; port++ {
|
||||
address := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("Couldn't bind plugin TCP listener")
|
||||
}
|
||||
|
||||
func serverListener_unix() (net.Listener, error) {
|
||||
tf, err := ioutil.TempFile("", "packer-plugin")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := tf.Name()
|
||||
|
||||
// Close the file and remove it because it has to not exist for
|
||||
// the domain socket.
|
||||
if err := tf.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.Remove(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return net.Listen("unix", path)
|
||||
}
|
Loading…
Reference in New Issue