diff --git a/packer/rpc/client.go b/packer/rpc/client.go index c201ebe19..83783d79b 100644 --- a/packer/rpc/client.go +++ b/packer/rpc/client.go @@ -72,6 +72,13 @@ func (c *Client) PostProcessor() packer.PostProcessor { } } +func (c *Client) Provisioner() packer.Provisioner { + return &provisioner{ + client: c.client, + mux: c.mux, + } +} + func (c *Client) Ui() packer.Ui { return &Ui{ client: c.client, diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index a62f97d45..eafd65756 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -10,24 +10,22 @@ import ( // executed over an RPC connection. type provisioner struct { client *rpc.Client + mux *MuxConn } // ProvisionerServer wraps a packer.Provisioner implementation and makes it // exportable as part of a Golang RPC server. type ProvisionerServer struct { - p packer.Provisioner + p packer.Provisioner + mux *MuxConn } type ProvisionerPrepareArgs struct { Configs []interface{} } -type ProvisionerProvisionArgs struct { - RPCAddress string -} - func Provisioner(client *rpc.Client) *provisioner { - return &provisioner{client} + return &provisioner{client: client} } func (p *provisioner) Prepare(configs ...interface{}) (err error) { args := &ProvisionerPrepareArgs{configs} @@ -39,13 +37,13 @@ func (p *provisioner) Prepare(configs ...interface{}) (err error) { } func (p *provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { - // TODO: Error handling - server := rpc.NewServer() - RegisterCommunicator(server, comm) - RegisterUi(server, ui) + nextId := p.mux.NextId() + server := NewServerWithMux(p.mux, nextId) + server.RegisterCommunicator(comm) + server.RegisterUi(ui) + go server.Serve() - args := &ProvisionerProvisionArgs{serveSingleConn(server)} - return p.client.Call("Provisioner.Provision", args, new(interface{})) + return p.client.Call("Provisioner.Provision", nextId, new(interface{})) } func (p *provisioner) Cancel() { @@ -64,16 +62,14 @@ func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *error) return nil } -func (p *ProvisionerServer) Provision(args *ProvisionerProvisionArgs, reply *interface{}) error { - client, err := rpcDial(args.RPCAddress) +func (p *ProvisionerServer) Provision(streamId uint32, reply *interface{}) error { + client, err := NewClientWithMux(p.mux, streamId) if err != nil { - return err + return NewBasicError(err) } + defer client.Close() - comm := Communicator(client) - ui := &Ui{client: client} - - if err := p.p.Provision(ui, comm); err != nil { + if err := p.p.Provision(client.Ui(), client.Communicator()); err != nil { return NewBasicError(err) } diff --git a/packer/rpc/provisioner_test.go b/packer/rpc/provisioner_test.go index 7e281a1b5..5d1604b74 100644 --- a/packer/rpc/provisioner_test.go +++ b/packer/rpc/provisioner_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) @@ -12,19 +11,14 @@ func TestProvisionerRPC(t *testing.T) { p := new(packer.MockProvisioner) // Start the server - server := rpc.NewServer() - RegisterProvisioner(server, p) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterProvisioner(p) + pClient := client.Provisioner() // Test Prepare config := 42 - pClient := Provisioner(client) pClient.Prepare(config) if !p.PrepCalled { t.Fatal("should be called") @@ -41,11 +35,6 @@ func TestProvisionerRPC(t *testing.T) { t.Fatal("should be called") } - p.ProvUi.Say("foo") - if !ui.sayCalled { - t.Fatal("should be called") - } - // Test Cancel pClient.Cancel() if !p.CancelCalled { diff --git a/packer/rpc/server.go b/packer/rpc/server.go index fc29c7e91..904a1b1aa 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -61,7 +61,7 @@ func RegisterPostProcessor(s *rpc.Server, p packer.PostProcessor) { // Registers the appropriate endpoint on an RPC server to serve a packer.Provisioner func RegisterProvisioner(s *rpc.Server, p packer.Provisioner) { - registerComponent(s, "Provisioner", &ProvisionerServer{p}, false) + registerComponent(s, "Provisioner", &ProvisionerServer{p: p}, false) } // Registers the appropriate endpoint on an RPC server to serve a diff --git a/packer/rpc/server_new.go b/packer/rpc/server_new.go index b8cb9fffa..b80f6e441 100644 --- a/packer/rpc/server_new.go +++ b/packer/rpc/server_new.go @@ -17,6 +17,7 @@ const ( DefaultCommunicatorEndpoint = "Communicator" DefaultHookEndpoint = "Hook" DefaultPostProcessorEndpoint = "PostProcessor" + DefaultProvisionerEndpoint = "Provisioner" DefaultUiEndpoint = "Ui" ) @@ -78,6 +79,13 @@ func (s *Server) RegisterPostProcessor(p packer.PostProcessor) { }) } +func (s *Server) RegisterProvisioner(p packer.Provisioner) { + s.server.RegisterName(DefaultProvisionerEndpoint, &ProvisionerServer{ + mux: s.mux, + p: p, + }) +} + func (s *Server) RegisterUi(ui packer.Ui) { s.server.RegisterName(DefaultUiEndpoint, &UiServer{ ui: ui,