diff --git a/packer/plugin/client.go b/packer/plugin/client.go index 792f9d99d..3fefca15f 100644 --- a/packer/plugin/client.go +++ b/packer/plugin/client.go @@ -10,7 +10,6 @@ import ( "io/ioutil" "log" "net" - "net/rpc" "os" "os/exec" "strings" @@ -130,56 +129,56 @@ func (c *Client) Exited() bool { // Returns a builder implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Builder() (packer.Builder, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdBuilder{packrpc.Builder(client), c}, nil + return &cmdBuilder{client.Builder(), c}, nil } // Returns a command implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Command() (packer.Command, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdCommand{packrpc.Command(client), c}, nil + return &cmdCommand{client.Command(), c}, nil } // Returns a hook implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Hook() (packer.Hook, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdHook{packrpc.Hook(client), c}, nil + return &cmdHook{client.Hook(), c}, nil } // Returns a post-processor implementation that is communicating over // this client. If the client hasn't been started, this will start it. func (c *Client) PostProcessor() (packer.PostProcessor, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdPostProcessor{packrpc.PostProcessor(client), c}, nil + return &cmdPostProcessor{client.PostProcessor(), c}, nil } // Returns a provisioner implementation that is communicating over this // client. If the client hasn't been started, this will start it. func (c *Client) Provisioner() (packer.Provisioner, error) { - client, err := c.rpcClient() + client, err := c.packrpcClient() if err != nil { return nil, err } - return &cmdProvisioner{packrpc.Provisioner(client), c}, nil + return &cmdProvisioner{client.Provisioner(), c}, nil } // End the executing subprocess (if it is running) and perform any cleanup @@ -361,7 +360,7 @@ func (c *Client) logStderr(r io.Reader) { close(c.doneLogging) } -func (c *Client) rpcClient() (*rpc.Client, error) { +func (c *Client) packrpcClient() (*packrpc.Client, error) { address, err := c.Start() if err != nil { return nil, err @@ -376,5 +375,11 @@ func (c *Client) rpcClient() (*rpc.Client, error) { tcpConn := conn.(*net.TCPConn) tcpConn.SetKeepAlive(true) - return rpc.NewClient(tcpConn), nil + client, err := packrpc.NewClient(tcpConn) + if err != nil { + tcpConn.Close() + return nil, err + } + + return client, nil } diff --git a/packer/plugin/plugin.go b/packer/plugin/plugin.go index a91fcc3ce..8012ffc4f 100644 --- a/packer/plugin/plugin.go +++ b/packer/plugin/plugin.go @@ -14,7 +14,6 @@ import ( packrpc "github.com/mitchellh/packer/packer/rpc" "log" "net" - "net/rpc" "os" "os/signal" "runtime" @@ -35,13 +34,14 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69 // know how to speak it. const APIVersion = "1" -// This serves a single RPC connection on the given RPC server on -// a random port. -func serve(server *rpc.Server) (err error) { +// Server waits for a connection to this plugin and returns a Packer +// RPC server that you can use to register components and serve them. +func Server() (*packrpc.Server, error) { log.Printf("Plugin build against Packer '%s'", packer.GitCommit) if os.Getenv(MagicCookieKey) != MagicCookieValue { - return errors.New("Please do not execute plugins directly. Packer will execute these for you.") + return nil, errors.New( + "Please do not execute plugins directly. Packer will execute these for you.") } // If there is no explicit number of Go threads to use, then set it @@ -51,12 +51,12 @@ func serve(server *rpc.Server) (err error) { minPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MIN_PORT"), 10, 32) if err != nil { - return + return nil, err } maxPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MAX_PORT"), 10, 32) if err != nil { - return + return nil, err } log.Printf("Plugin minimum port: %d\n", minPort) @@ -77,7 +77,6 @@ func serve(server *rpc.Server) (err error) { break } - defer listener.Close() // Output the address to stdout @@ -90,102 +89,22 @@ func serve(server *rpc.Server) (err error) { conn, err := listener.Accept() if err != nil { log.Printf("Error accepting connection: %s\n", err.Error()) - return + return nil, err } - // Serve a single connection - log.Println("Serving a plugin connection...") - server.ServeConn(conn) - return -} - -// Registers a signal handler to swallow and count interrupts so that the -// plugin isn't killed. The main host Packer process is responsible -// for killing the plugins when interrupted. -func countInterrupts() { + // Eat the interrupts ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt) - go func() { + var count int32 = 0 for { <-ch - newCount := atomic.AddInt32(&Interrupts, 1) + newCount := atomic.AddInt32(&count, 1) log.Printf("Received interrupt signal (count: %d). Ignoring.", newCount) } }() -} - -// Serves a builder from a plugin. -func ServeBuilder(builder packer.Builder) { - log.Println("Preparing to serve a builder plugin...") - - server := rpc.NewServer() - packrpc.RegisterBuilder(server, builder) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a command from a plugin. -func ServeCommand(command packer.Command) { - log.Println("Preparing to serve a command plugin...") - - server := rpc.NewServer() - packrpc.RegisterCommand(server, command) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a hook from a plugin. -func ServeHook(hook packer.Hook) { - log.Println("Preparing to serve a hook plugin...") - - server := rpc.NewServer() - packrpc.RegisterHook(server, hook) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a post-processor from a plugin. -func ServePostProcessor(p packer.PostProcessor) { - log.Println("Preparing to serve a post-processor plugin...") - - server := rpc.NewServer() - packrpc.RegisterPostProcessor(server, p) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Serves a provisioner from a plugin. -func ServeProvisioner(p packer.Provisioner) { - log.Println("Preparing to serve a provisioner plugin...") - - server := rpc.NewServer() - packrpc.RegisterProvisioner(server, p) - - countInterrupts() - if err := serve(server); err != nil { - log.Printf("ERROR: %s", err) - os.Exit(1) - } -} - -// Tests whether or not the plugin was interrupted or not. -func Interrupted() bool { - return atomic.LoadInt32(&Interrupts) > 0 + + // Serve a single connection + log.Println("Serving a plugin connection...") + return packrpc.NewServer(conn), nil } diff --git a/packer/plugin/plugin_test.go b/packer/plugin/plugin_test.go index d80aceaf7..733190ec3 100644 --- a/packer/plugin/plugin_test.go +++ b/packer/plugin/plugin_test.go @@ -54,20 +54,50 @@ func TestHelperProcess(*testing.T) { fmt.Printf("%s1|:1234\n", APIVersion) <-make(chan int) case "builder": - ServeBuilder(new(packer.MockBuilder)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterBuilder(new(packer.MockBuilder)) + server.Serve() case "command": - ServeCommand(new(helperCommand)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterCommand(new(helperCommand)) + server.Serve() case "hook": - ServeHook(new(packer.MockHook)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterHook(new(packer.MockHook)) + server.Serve() case "invalid-rpc-address": fmt.Println("lolinvalid") case "mock": fmt.Printf("%s|:1234\n", APIVersion) <-make(chan int) case "post-processor": - ServePostProcessor(new(helperPostProcessor)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterPostProcessor(new(helperPostProcessor)) + server.Serve() case "provisioner": - ServeProvisioner(new(packer.MockProvisioner)) + server, err := Server() + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + server.RegisterProvisioner(new(packer.MockProvisioner)) + server.Serve() case "start-timeout": time.Sleep(1 * time.Minute) os.Exit(1) diff --git a/packer/rpc/artifact.go b/packer/rpc/artifact.go index d71f8831a..c6e0a208d 100644 --- a/packer/rpc/artifact.go +++ b/packer/rpc/artifact.go @@ -8,7 +8,8 @@ import ( // An implementation of packer.Artifact where the artifact is actually // available over an RPC connection. type artifact struct { - client *rpc.Client + client *rpc.Client + endpoint string } // ArtifactServer wraps a packer.Artifact implementation and makes it @@ -17,33 +18,29 @@ type ArtifactServer struct { artifact packer.Artifact } -func Artifact(client *rpc.Client) *artifact { - return &artifact{client} -} - func (a *artifact) BuilderId() (result string) { - a.client.Call("Artifact.BuilderId", new(interface{}), &result) + a.client.Call(a.endpoint+".BuilderId", new(interface{}), &result) return } func (a *artifact) Files() (result []string) { - a.client.Call("Artifact.Files", new(interface{}), &result) + a.client.Call(a.endpoint+".Files", new(interface{}), &result) return } func (a *artifact) Id() (result string) { - a.client.Call("Artifact.Id", new(interface{}), &result) + a.client.Call(a.endpoint+".Id", new(interface{}), &result) return } func (a *artifact) String() (result string) { - a.client.Call("Artifact.String", new(interface{}), &result) + a.client.Call(a.endpoint+".String", new(interface{}), &result) return } func (a *artifact) Destroy() error { var result error - if err := a.client.Call("Artifact.Destroy", new(interface{}), &result); err != nil { + if err := a.client.Call(a.endpoint+".Destroy", new(interface{}), &result); err != nil { return err } diff --git a/packer/rpc/artifact_test.go b/packer/rpc/artifact_test.go index 336fa6d8b..37f4cef9f 100644 --- a/packer/rpc/artifact_test.go +++ b/packer/rpc/artifact_test.go @@ -2,48 +2,21 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) -type testArtifact struct{} - -func (testArtifact) BuilderId() string { - return "bid" -} - -func (testArtifact) Files() []string { - return []string{"a", "b"} -} - -func (testArtifact) Id() string { - return "id" -} - -func (testArtifact) String() string { - return "string" -} - -func (testArtifact) Destroy() error { - return nil -} - func TestArtifactRPC(t *testing.T) { // Create the interface to test - a := new(testArtifact) + a := new(packer.MockArtifact) // Start the server - server := rpc.NewServer() - RegisterArtifact(server, a) - address := serveSingleConn(server) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterArtifact(a) - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - aClient := Artifact(client) + aClient := client.Artifact() // Test if aClient.BuilderId() != "bid" { @@ -64,5 +37,5 @@ func TestArtifactRPC(t *testing.T) { } func TestArtifact_Implements(t *testing.T) { - var _ packer.Artifact = Artifact(nil) + var _ packer.Artifact = new(artifact) } diff --git a/packer/rpc/build.go b/packer/rpc/build.go index 2df03a6ab..f63f35f76 100644 --- a/packer/rpc/build.go +++ b/packer/rpc/build.go @@ -9,16 +9,14 @@ import ( // over an RPC connection. type build struct { client *rpc.Client + mux *MuxConn } // BuildServer wraps a packer.Build implementation and makes it exportable // as part of a Golang RPC server. type BuildServer struct { build packer.Build -} - -type BuildRunArgs struct { - UiRPCAddress string + mux *MuxConn } type BuildPrepareResponse struct { @@ -26,10 +24,6 @@ type BuildPrepareResponse struct { Error error } -func Build(client *rpc.Client) *build { - return &build{client} -} - func (b *build) Name() (result string) { b.client.Call("Build.Name", new(interface{}), &result) return @@ -45,25 +39,25 @@ func (b *build) Prepare(v map[string]string) ([]string, error) { } func (b *build) Run(ui packer.Ui, cache packer.Cache) ([]packer.Artifact, error) { - // Create and start the server for the UI - server := rpc.NewServer() - RegisterCache(server, cache) - RegisterUi(server, ui) - args := &BuildRunArgs{serveSingleConn(server)} + nextId := b.mux.NextId() + server := NewServerWithMux(b.mux, nextId) + server.RegisterCache(cache) + server.RegisterUi(ui) + go server.Serve() - var result []string - if err := b.client.Call("Build.Run", args, &result); err != nil { + var result []uint32 + if err := b.client.Call("Build.Run", nextId, &result); err != nil { return nil, err } artifacts := make([]packer.Artifact, len(result)) - for i, addr := range result { - client, err := rpcDial(addr) + for i, streamId := range result { + client, err := NewClientWithMux(b.mux, streamId) if err != nil { return nil, err } - artifacts[i] = Artifact(client) + artifacts[i] = client.Artifact() } return artifacts, nil @@ -101,22 +95,26 @@ func (b *BuildServer) Prepare(v map[string]string, resp *BuildPrepareResponse) e return nil } -func (b *BuildServer) Run(args *BuildRunArgs, reply *[]string) error { - client, err := rpcDial(args.UiRPCAddress) +func (b *BuildServer) Run(streamId uint32, reply *[]uint32) error { + client, err := NewClientWithMux(b.mux, streamId) if err != nil { - return err + return NewBasicError(err) } + defer client.Close() - artifacts, err := b.build.Run(&Ui{client}, Cache(client)) + artifacts, err := b.build.Run(client.Ui(), client.Cache()) if err != nil { return NewBasicError(err) } - *reply = make([]string, len(artifacts)) + *reply = make([]uint32, len(artifacts)) for i, artifact := range artifacts { - server := rpc.NewServer() - RegisterArtifact(server, artifact) - (*reply)[i] = serveSingleConn(server) + streamId := b.mux.NextId() + server := NewServerWithMux(b.mux, streamId) + server.RegisterArtifact(artifact) + go server.Serve() + + (*reply)[i] = streamId } return nil diff --git a/packer/rpc/build_test.go b/packer/rpc/build_test.go index 3cb5e1fb2..6c18ad384 100644 --- a/packer/rpc/build_test.go +++ b/packer/rpc/build_test.go @@ -3,12 +3,11 @@ package rpc import ( "errors" "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) -var testBuildArtifact = &testArtifact{} +var testBuildArtifact = &packer.MockArtifact{} type testBuild struct { nameCalled bool @@ -60,25 +59,13 @@ func (b *testBuild) Cancel() { b.cancelCalled = true } -func buildRPCClient(t *testing.T) (*testBuild, packer.Build) { - // Create the interface to test - b := new(testBuild) - - // Start the server - server := rpc.NewServer() - RegisterBuild(server, b) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - return b, Build(client) -} - func TestBuild(t *testing.T) { - b, bClient := buildRPCClient(t) + b := new(testBuild) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuild(b) + bClient := client.Build() // Test Name bClient.Name() @@ -120,23 +107,6 @@ func TestBuild(t *testing.T) { t.Fatalf("bad: %#v", artifacts) } - // Test the UI given to run, which should be fully functional - if b.runCalled { - b.runCache.Lock("foo") - if !cache.lockCalled { - t.Fatal("lock shuld be called") - } - - b.runUi.Say("format") - if !ui.sayCalled { - t.Fatal("say should be called") - } - - if ui.sayMessage != "format" { - t.Fatalf("bad: %#v", ui.sayMessage) - } - } - // Test run with an error b.errRunResult = true _, err = bClient.Run(ui, cache) @@ -164,7 +134,12 @@ func TestBuild(t *testing.T) { } func TestBuildPrepare_Warnings(t *testing.T) { - b, bClient := buildRPCClient(t) + b := new(testBuild) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuild(b) + bClient := client.Build() expected := []string{"foo"} b.prepareWarnings = expected @@ -179,5 +154,5 @@ func TestBuildPrepare_Warnings(t *testing.T) { } func TestBuild_ImplementsBuild(t *testing.T) { - var _ packer.Build = Build(nil) + var _ packer.Build = new(build) } diff --git a/packer/rpc/builder.go b/packer/rpc/builder.go index 42f92c4d4..7721414e2 100644 --- a/packer/rpc/builder.go +++ b/packer/rpc/builder.go @@ -1,8 +1,6 @@ package rpc import ( - "encoding/gob" - "fmt" "github.com/mitchellh/packer/packer" "log" "net/rpc" @@ -12,37 +10,25 @@ import ( // over an RPC connection. type builder struct { client *rpc.Client + mux *MuxConn } // BuilderServer wraps a packer.Builder implementation and makes it exportable // as part of a Golang RPC server. type BuilderServer struct { builder packer.Builder + mux *MuxConn } type BuilderPrepareArgs struct { Configs []interface{} } -type BuilderRunArgs struct { - RPCAddress string - ResponseAddress string -} - type BuilderPrepareResponse struct { Warnings []string Error error } -type BuilderRunResponse struct { - Err error - RPCAddress string -} - -func Builder(client *rpc.Client) *builder { - return &builder{client} -} - func (b *builder) Prepare(config ...interface{}) ([]string, error) { var resp BuilderPrepareResponse cerr := b.client.Call("Builder.Prepare", &BuilderPrepareArgs{config}, &resp) @@ -54,58 +40,28 @@ func (b *builder) Prepare(config ...interface{}) ([]string, error) { } func (b *builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packer.Artifact, error) { - // Create and start the server for the Build and UI - server := rpc.NewServer() - RegisterCache(server, cache) - RegisterHook(server, hook) - RegisterUi(server, ui) + nextId := b.mux.NextId() + server := NewServerWithMux(b.mux, nextId) + server.RegisterCache(cache) + server.RegisterHook(hook) + server.RegisterUi(ui) + go server.Serve() - // Create a server for the response - responseL := netListenerInRange(portRangeMin, portRangeMax) - runResponseCh := make(chan *BuilderRunResponse) - go func() { - defer responseL.Close() - - var response BuilderRunResponse - defer func() { runResponseCh <- &response }() - - conn, err := responseL.Accept() - if err != nil { - response.Err = err - return - } - defer conn.Close() - - decoder := gob.NewDecoder(conn) - if err := decoder.Decode(&response); err != nil { - response.Err = fmt.Errorf("Error waiting for Run: %s", err) - } - }() - - args := &BuilderRunArgs{ - serveSingleConn(server), - responseL.Addr().String(), - } - - if err := b.client.Call("Builder.Run", args, new(interface{})); err != nil { + var responseId uint32 + if err := b.client.Call("Builder.Run", nextId, &responseId); err != nil { return nil, err } - response := <-runResponseCh - if response.Err != nil { - return nil, response.Err - } - - if response.RPCAddress == "" { + if responseId == 0 { return nil, nil } - client, err := rpcDial(response.RPCAddress) + client, err := NewClientWithMux(b.mux, responseId) if err != nil { return nil, err } - return Artifact(client), nil + return client.Artifact(), nil } func (b *builder) Cancel() { @@ -127,46 +83,27 @@ func (b *BuilderServer) Prepare(args *BuilderPrepareArgs, reply *BuilderPrepareR return nil } -func (b *BuilderServer) Run(args *BuilderRunArgs, reply *interface{}) error { - client, err := rpcDial(args.RPCAddress) +func (b *BuilderServer) Run(streamId uint32, reply *uint32) error { + client, err := NewClientWithMux(b.mux, streamId) if err != nil { - return err + return NewBasicError(err) + } + defer client.Close() + + artifact, err := b.builder.Run(client.Ui(), client.Hook(), client.Cache()) + if err != nil { + return NewBasicError(err) } - responseC, err := tcpDial(args.ResponseAddress) - if err != nil { - return err + *reply = 0 + if artifact != nil { + streamId = b.mux.NextId() + server := NewServerWithMux(b.mux, streamId) + server.RegisterArtifact(artifact) + go server.Serve() + *reply = streamId } - responseWriter := gob.NewEncoder(responseC) - - // Run the build in a goroutine so we don't block the RPC connection - go func() { - defer responseC.Close() - - cache := Cache(client) - hook := Hook(client) - ui := &Ui{client} - artifact, responseErr := b.builder.Run(ui, hook, cache) - responseAddress := "" - - if responseErr == nil && artifact != nil { - // Wrap the artifact - server := rpc.NewServer() - RegisterArtifact(server, artifact) - responseAddress = serveSingleConn(server) - } - - if responseErr != nil { - responseErr = NewBasicError(responseErr) - } - - err := responseWriter.Encode(&BuilderRunResponse{responseErr, responseAddress}) - if err != nil { - log.Printf("BuildServer.Run error: %s", err) - } - }() - return nil } diff --git a/packer/rpc/builder_test.go b/packer/rpc/builder_test.go index fa20a4b85..37e8041bb 100644 --- a/packer/rpc/builder_test.go +++ b/packer/rpc/builder_test.go @@ -2,31 +2,19 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) -var testBuilderArtifact = &testArtifact{} - -func builderRPCClient(t *testing.T) (*packer.MockBuilder, packer.Builder) { - b := new(packer.MockBuilder) - - // Start the server - server := rpc.NewServer() - RegisterBuilder(server, b) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - return b, Builder(client) -} +var testBuilderArtifact = &packer.MockArtifact{} func TestBuilderPrepare(t *testing.T) { - b, bClient := builderRPCClient(t) + b := new(packer.MockBuilder) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuilder(b) + bClient := client.Builder() // Test Prepare config := 42 @@ -48,7 +36,12 @@ func TestBuilderPrepare(t *testing.T) { } func TestBuilderPrepare_Warnings(t *testing.T) { - b, bClient := builderRPCClient(t) + b := new(packer.MockBuilder) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuilder(b) + bClient := client.Builder() expected := []string{"foo"} b.PrepareWarnings = expected @@ -64,7 +57,12 @@ func TestBuilderPrepare_Warnings(t *testing.T) { } func TestBuilderRun(t *testing.T) { - b, bClient := builderRPCClient(t) + b := new(packer.MockBuilder) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuilder(b) + bClient := client.Builder() // Test Run cache := new(testCache) @@ -79,34 +77,21 @@ func TestBuilderRun(t *testing.T) { t.Fatal("run should be called") } - b.RunCache.Lock("foo") - if !cache.lockCalled { - t.Fatal("should be called") - } - - b.RunHook.Run("foo", nil, nil, nil) - if !hook.RunCalled { - t.Fatal("should be called") - } - - b.RunUi.Say("format") - if !ui.sayCalled { - t.Fatal("say should be called") - } - - if ui.sayMessage != "format" { - t.Fatalf("bad: %s", ui.sayMessage) - } - if artifact.Id() != testBuilderArtifact.Id() { t.Fatalf("bad: %s", artifact.Id()) } } func TestBuilderRun_nilResult(t *testing.T) { - b, bClient := builderRPCClient(t) + b := new(packer.MockBuilder) b.RunNilResult = true + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuilder(b) + bClient := client.Builder() + cache := new(testCache) hook := &packer.MockHook{} ui := &testUi{} @@ -120,7 +105,13 @@ func TestBuilderRun_nilResult(t *testing.T) { } func TestBuilderRun_ErrResult(t *testing.T) { - b, bClient := builderRPCClient(t) + b := new(packer.MockBuilder) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuilder(b) + bClient := client.Builder() + b.RunErrResult = true cache := new(testCache) @@ -136,7 +127,12 @@ func TestBuilderRun_ErrResult(t *testing.T) { } func TestBuilderCancel(t *testing.T) { - b, bClient := builderRPCClient(t) + b := new(packer.MockBuilder) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterBuilder(b) + bClient := client.Builder() bClient.Cancel() if !b.CancelCalled { @@ -145,5 +141,5 @@ func TestBuilderCancel(t *testing.T) { } func TestBuilder_ImplementsBuilder(t *testing.T) { - var _ packer.Builder = Builder(nil) + var _ packer.Builder = new(builder) } diff --git a/packer/rpc/cache.go b/packer/rpc/cache.go index 73b154cb6..184286411 100644 --- a/packer/rpc/cache.go +++ b/packer/rpc/cache.go @@ -17,10 +17,6 @@ type CacheServer struct { cache packer.Cache } -func Cache(client *rpc.Client) *cache { - return &cache{client} -} - type CacheRLockResponse struct { Path string Exists bool diff --git a/packer/rpc/cache_test.go b/packer/rpc/cache_test.go index 46cb81d08..702ca98e4 100644 --- a/packer/rpc/cache_test.go +++ b/packer/rpc/cache_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "testing" ) @@ -40,11 +39,7 @@ func (t *testCache) RUnlock(key string) { } func TestCache_Implements(t *testing.T) { - var raw interface{} - raw = Cache(nil) - if _, ok := raw.(packer.Cache); !ok { - t.Fatal("Cache must be a cache.") - } + var _ packer.Cache = new(cache) } func TestCacheRPC(t *testing.T) { @@ -52,19 +47,15 @@ func TestCacheRPC(t *testing.T) { c := new(testCache) // Start the server - server := rpc.NewServer() - RegisterCache(server, c) - address := serveSingleConn(server) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterCache(c) - // Create the client over RPC and run some methods to verify it works - rpcClient, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("bad: %s", err) - } - client := Cache(rpcClient) + cacheClient := client.Cache() // Test Lock - client.Lock("foo") + cacheClient.Lock("foo") if !c.lockCalled { t.Fatal("should be called") } @@ -73,7 +64,7 @@ func TestCacheRPC(t *testing.T) { } // Test Unlock - client.Unlock("foo") + cacheClient.Unlock("foo") if !c.unlockCalled { t.Fatal("should be called") } @@ -82,7 +73,7 @@ func TestCacheRPC(t *testing.T) { } // Test RLock - client.RLock("foo") + cacheClient.RLock("foo") if !c.rlockCalled { t.Fatal("should be called") } @@ -91,7 +82,7 @@ func TestCacheRPC(t *testing.T) { } // Test RUnlock - client.RUnlock("foo") + cacheClient.RUnlock("foo") if !c.runlockCalled { t.Fatal("should be called") } diff --git a/packer/rpc/client.go b/packer/rpc/client.go new file mode 100644 index 000000000..c9ddc5f57 --- /dev/null +++ b/packer/rpc/client.go @@ -0,0 +1,115 @@ +package rpc + +import ( + "github.com/mitchellh/packer/packer" + "io" + "net/rpc" +) + +// Client is the client end that communicates with a Packer RPC server. +// Establishing a connection is up to the user, the Client can just +// communicate over any ReadWriteCloser. +type Client struct { + mux *MuxConn + client *rpc.Client +} + +func NewClient(rwc io.ReadWriteCloser) (*Client, error) { + return NewClientWithMux(NewMuxConn(rwc), 0) +} + +func NewClientWithMux(mux *MuxConn, streamId uint32) (*Client, error) { + clientConn, err := mux.Dial(streamId) + if err != nil { + return nil, err + } + + return &Client{ + mux: mux, + client: rpc.NewClient(clientConn), + }, nil +} + +func (c *Client) Close() error { + if err := c.client.Close(); err != nil { + return err + } + + return nil +} + +func (c *Client) Artifact() packer.Artifact { + return &artifact{ + client: c.client, + endpoint: DefaultArtifactEndpoint, + } +} + +func (c *Client) Build() packer.Build { + return &build{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Builder() packer.Builder { + return &builder{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Cache() packer.Cache { + return &cache{ + client: c.client, + } +} + +func (c *Client) Command() packer.Command { + return &command{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Communicator() packer.Communicator { + return &communicator{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Environment() packer.Environment { + return &Environment{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Hook() packer.Hook { + return &hook{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) PostProcessor() packer.PostProcessor { + return &postProcessor{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Provisioner() packer.Provisioner { + return &provisioner{ + client: c.client, + mux: c.mux, + } +} + +func (c *Client) Ui() packer.Ui { + return &Ui{ + client: c.client, + endpoint: DefaultUiEndpoint, + } +} diff --git a/packer/rpc/client_test.go b/packer/rpc/client_test.go new file mode 100644 index 000000000..c0595cbe4 --- /dev/null +++ b/packer/rpc/client_test.go @@ -0,0 +1,48 @@ +package rpc + +import ( + "net" + "testing" +) + +func testConn(t *testing.T) (net.Conn, net.Conn) { + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("err: %s", err) + } + + var serverConn net.Conn + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + defer l.Close() + var err error + serverConn, err = l.Accept() + if err != nil { + t.Fatalf("err: %s", err) + } + }() + + clientConn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + <-doneCh + + return clientConn, serverConn +} + +func testClientServer(t *testing.T) (*Client, *Server) { + clientConn, serverConn := testConn(t) + + server := NewServer(serverConn) + go server.Serve() + + client, err := NewClient(clientConn) + if err != nil { + server.Close() + t.Fatalf("err: %s", err) + } + + return client, server +} diff --git a/packer/rpc/command.go b/packer/rpc/command.go index 3e2b48b2f..e0243e9bb 100644 --- a/packer/rpc/command.go +++ b/packer/rpc/command.go @@ -9,25 +9,23 @@ import ( // command is actually executed over an RPC connection. type command struct { client *rpc.Client + mux *MuxConn } // A CommandServer wraps a packer.Command and makes it exportable as part // of a Golang RPC server. type CommandServer struct { command packer.Command + mux *MuxConn } type CommandRunArgs struct { - RPCAddress string Args []string + StreamId uint32 } type CommandSynopsisArgs byte -func Command(client *rpc.Client) *command { - return &command{client} -} - func (c *command) Help() (result string) { err := c.client.Call("Command.Help", new(interface{}), &result) if err != nil { @@ -38,11 +36,15 @@ func (c *command) Help() (result string) { } func (c *command) Run(env packer.Environment, args []string) (result int) { - // Create and start the server for the Environment - server := rpc.NewServer() - RegisterEnvironment(server, env) + nextId := c.mux.NextId() + server := NewServerWithMux(c.mux, nextId) + server.RegisterEnvironment(env) + go server.Serve() - rpcArgs := &CommandRunArgs{serveSingleConn(server), args} + rpcArgs := &CommandRunArgs{ + Args: args, + StreamId: nextId, + } err := c.client.Call("Command.Run", rpcArgs, &result) if err != nil { panic(err) @@ -66,14 +68,13 @@ func (c *CommandServer) Help(args *interface{}, reply *string) error { } func (c *CommandServer) Run(args *CommandRunArgs, reply *int) error { - client, err := rpcDial(args.RPCAddress) + client, err := NewClientWithMux(c.mux, args.StreamId) if err != nil { - return err + return NewBasicError(err) } + defer client.Close() - env := &Environment{client} - - *reply = c.command.Run(env, args.Args) + *reply = c.command.Run(client.Environment(), args.Args) return nil } diff --git a/packer/rpc/command_test.go b/packer/rpc/command_test.go index 086e0eeed..7ba433956 100644 --- a/packer/rpc/command_test.go +++ b/packer/rpc/command_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) @@ -33,21 +32,14 @@ func TestRPCCommand(t *testing.T) { command := new(TestCommand) // Start the server - server := rpc.NewServer() - RegisterCommand(server, command) - address := serveSingleConn(server) - - // Create the command client over RPC and run some methods to verify - // we get the proper behavior. - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - - clientComm := Command(client) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterCommand(command) + commClient := client.Command() //Test Help - help := clientComm.Help() + help := commClient.Help() if help != "bar" { t.Fatalf("bad: %s", help) } @@ -55,7 +47,7 @@ func TestRPCCommand(t *testing.T) { // Test run runArgs := []string{"foo", "bar"} testEnv := &testEnvironment{} - exitCode := clientComm.Run(testEnv, runArgs) + exitCode := commClient.Run(testEnv, runArgs) if !reflect.DeepEqual(command.runArgs, runArgs) { t.Fatalf("bad: %#v", command.runArgs) } @@ -67,18 +59,13 @@ func TestRPCCommand(t *testing.T) { t.Fatal("runEnv should not be nil") } - command.runEnv.Ui() - if !testEnv.uiCalled { - t.Fatal("ui should be called") - } - // Test Synopsis - synopsis := clientComm.Synopsis() + synopsis := commClient.Synopsis() if synopsis != "foo" { t.Fatalf("bad: %#v", synopsis) } } func TestCommand_Implements(t *testing.T) { - var _ packer.Command = Command(nil) + var _ packer.Command = new(command) } diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index 77e321153..8f470075e 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -2,11 +2,9 @@ package rpc import ( "encoding/gob" - "errors" "github.com/mitchellh/packer/packer" "io" "log" - "net" "net/rpc" ) @@ -14,12 +12,14 @@ import ( // executed over an RPC connection. type communicator struct { client *rpc.Client + mux *MuxConn } // CommunicatorServer wraps a packer.Communicator implementation and makes // it exportable as part of a Golang RPC server. type CommunicatorServer struct { - c packer.Communicator + c packer.Communicator + mux *MuxConn } type CommandFinished struct { @@ -27,21 +27,21 @@ type CommandFinished struct { } type CommunicatorStartArgs struct { - Command string - StdinAddress string - StdoutAddress string - StderrAddress string - ResponseAddress string + Command string + StdinStreamId uint32 + StdoutStreamId uint32 + StderrStreamId uint32 + ResponseStreamId uint32 } type CommunicatorDownloadArgs struct { - Path string - WriterAddress string + Path string + WriterStreamId uint32 } type CommunicatorUploadArgs struct { - Path string - ReaderAddress string + Path string + ReaderStreamId uint32 } type CommunicatorUploadDirArgs struct { @@ -51,7 +51,7 @@ type CommunicatorUploadDirArgs struct { } func Communicator(client *rpc.Client) *communicator { - return &communicator{client} + return &communicator{client: client} } func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { @@ -59,45 +59,43 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { args.Command = cmd.Command if cmd.Stdin != nil { - stdinL := netListenerInRange(portRangeMin, portRangeMax) - args.StdinAddress = stdinL.Addr().String() - go serveSingleCopy("stdin", stdinL, nil, cmd.Stdin) + args.StdinStreamId = c.mux.NextId() + go serveSingleCopy("stdin", c.mux, args.StdinStreamId, nil, cmd.Stdin) } if cmd.Stdout != nil { - stdoutL := netListenerInRange(portRangeMin, portRangeMax) - args.StdoutAddress = stdoutL.Addr().String() - go serveSingleCopy("stdout", stdoutL, cmd.Stdout, nil) + args.StdoutStreamId = c.mux.NextId() + go serveSingleCopy("stdout", c.mux, args.StdoutStreamId, cmd.Stdout, nil) } if cmd.Stderr != nil { - stderrL := netListenerInRange(portRangeMin, portRangeMax) - args.StderrAddress = stderrL.Addr().String() - go serveSingleCopy("stderr", stderrL, cmd.Stderr, nil) + args.StderrStreamId = c.mux.NextId() + go serveSingleCopy("stderr", c.mux, args.StderrStreamId, cmd.Stderr, nil) } - responseL := netListenerInRange(portRangeMin, portRangeMax) - args.ResponseAddress = responseL.Addr().String() + responseStreamId := c.mux.NextId() + args.ResponseStreamId = responseStreamId go func() { - defer responseL.Close() - - conn, err := responseL.Accept() + conn, err := c.mux.Accept(responseStreamId) if err != nil { + log.Printf("[ERR] Error accepting response stream %d: %s", + responseStreamId, err) cmd.SetExited(123) return } - defer conn.Close() - decoder := gob.NewDecoder(conn) - var finished CommandFinished + decoder := gob.NewDecoder(conn) if err := decoder.Decode(&finished); err != nil { + log.Printf("[ERR] Error decoding response stream %d: %s", + responseStreamId, err) cmd.SetExited(123) return } + log.Printf("[INFO] RPC client: Communicator ended with: %d", finished.ExitStatus) cmd.SetExited(finished.ExitStatus) }() @@ -106,23 +104,13 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { } func (c *communicator) Upload(path string, r io.Reader) (err error) { - // We need to create a server that can proxy the reader data - // over because we can't simply gob encode an io.Reader - readerL := netListenerInRange(portRangeMin, portRangeMax) - if readerL == nil { - err = errors.New("couldn't allocate listener for upload reader") - return - } - - // Make sure at the end of this call, we close the listener - defer readerL.Close() - // Pipe the reader through to the connection - go serveSingleCopy("uploadReader", readerL, nil, r) + streamId := c.mux.NextId() + go serveSingleCopy("uploadData", c.mux, streamId, nil, r) args := CommunicatorUploadArgs{ - path, - readerL.Addr().String(), + Path: path, + ReaderStreamId: streamId, } err = c.client.Call("Communicator.Upload", &args, new(interface{})) @@ -146,99 +134,104 @@ func (c *communicator) UploadDir(dst string, src string, exclude []string) error } func (c *communicator) Download(path string, w io.Writer) (err error) { - // We need to create a server that can proxy that data downloaded - // into the writer because we can't gob encode a writer directly. - writerL := netListenerInRange(portRangeMin, portRangeMax) - if writerL == nil { - err = errors.New("couldn't allocate listener for download writer") - return - } - - // Make sure we close the listener once we're done because we'll be done - defer writerL.Close() - // Serve a single connection and a single copy - go serveSingleCopy("downloadWriter", writerL, w, nil) + streamId := c.mux.NextId() + go serveSingleCopy("downloadWriter", c.mux, streamId, w, nil) args := CommunicatorDownloadArgs{ - path, - writerL.Addr().String(), + Path: path, + WriterStreamId: streamId, } err = c.client.Call("Communicator.Download", &args, new(interface{})) return } -func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (err error) { +func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (error) { // Build the RemoteCmd on this side so that it all pipes over // to the remote side. var cmd packer.RemoteCmd cmd.Command = args.Command - toClose := make([]net.Conn, 0) - if args.StdinAddress != "" { - stdinC, err := tcpDial(args.StdinAddress) + // Create a channel to signal we're done so that we can close + // our stdin/stdout/stderr streams + toClose := make([]io.Closer, 0) + doneCh := make(chan struct{}) + go func() { + <-doneCh + for _, conn := range toClose { + defer conn.Close() + } + }() + + if args.StdinStreamId > 0 { + conn, err := c.mux.Dial(args.StdinStreamId) if err != nil { - return err + close(doneCh) + return NewBasicError(err) } - toClose = append(toClose, stdinC) - cmd.Stdin = stdinC + toClose = append(toClose, conn) + cmd.Stdin = conn } - if args.StdoutAddress != "" { - stdoutC, err := tcpDial(args.StdoutAddress) + if args.StdoutStreamId > 0 { + conn, err := c.mux.Dial(args.StdoutStreamId) if err != nil { - return err + close(doneCh) + return NewBasicError(err) } - toClose = append(toClose, stdoutC) - cmd.Stdout = stdoutC + toClose = append(toClose, conn) + cmd.Stdout = conn } - if args.StderrAddress != "" { - stderrC, err := tcpDial(args.StderrAddress) + if args.StderrStreamId > 0 { + conn, err := c.mux.Dial(args.StderrStreamId) if err != nil { - return err + close(doneCh) + return NewBasicError(err) } - toClose = append(toClose, stderrC) - cmd.Stderr = stderrC + toClose = append(toClose, conn) + cmd.Stderr = conn } + // Connect to the response address so we can write our result to it // when ready. - responseC, err := tcpDial(args.ResponseAddress) + responseC, err := c.mux.Dial(args.ResponseStreamId) if err != nil { - return err + close(doneCh) + return NewBasicError(err) } - responseWriter := gob.NewEncoder(responseC) // Start the actual command err = c.c.Start(&cmd) + if err != nil { + close(doneCh) + return NewBasicError(err) + } // Start a goroutine to spin and wait for the process to actual // exit. When it does, report it back to caller... go func() { + defer close(doneCh) defer responseC.Close() - for _, conn := range toClose { - defer conn.Close() - } - cmd.Wait() + log.Printf("[INFO] RPC endpoint: Communicator ended with: %d", cmd.ExitStatus) responseWriter.Encode(&CommandFinished{cmd.ExitStatus}) }() - return + return nil } func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) { - readerC, err := tcpDial(args.ReaderAddress) + readerC, err := c.mux.Dial(args.ReaderStreamId) if err != nil { return } - defer readerC.Close() err = c.c.Upload(args.Path, readerC) @@ -250,21 +243,18 @@ func (c *CommunicatorServer) UploadDir(args *CommunicatorUploadDirArgs, reply *e } func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) { - writerC, err := tcpDial(args.WriterAddress) + writerC, err := c.mux.Dial(args.WriterStreamId) if err != nil { return } - defer writerC.Close() err = c.c.Download(args.Path, writerC) return } -func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) { - defer l.Close() - - conn, err := l.Accept() +func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { + conn, err := mux.Accept(id) if err != nil { log.Printf("'%s' accept error: %s", name, err) return diff --git a/packer/rpc/communicator_test.go b/packer/rpc/communicator_test.go index f6013a061..ca8239514 100644 --- a/packer/rpc/communicator_test.go +++ b/packer/rpc/communicator_test.go @@ -4,7 +4,6 @@ import ( "bufio" "github.com/mitchellh/packer/packer" "io" - "net/rpc" "reflect" "testing" ) @@ -14,16 +13,11 @@ func TestCommunicatorRPC(t *testing.T) { c := new(packer.MockCommunicator) // Start the server - server := rpc.NewServer() - RegisterCommunicator(server, c) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - remote := Communicator(client) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterCommunicator(c) + remote := client.Communicator() // The remote command we'll use stdin_r, stdin_w := io.Pipe() @@ -42,7 +36,7 @@ func TestCommunicatorRPC(t *testing.T) { c.StartExitStatus = 42 // Test Start - err = remote.Start(&cmd) + err := remote.Start(&cmd) if err != nil { t.Fatalf("err: %s", err) } @@ -74,7 +68,7 @@ func TestCommunicatorRPC(t *testing.T) { stdin_w.Close() cmd.Wait() if c.StartStdin != "info\n" { - t.Fatalf("bad data: %s", data) + t.Fatalf("bad data: %s", c.StartStdin) } // Test that we can get the exit status properly diff --git a/packer/rpc/environment.go b/packer/rpc/environment.go index 36db72c56..1aa3b9abb 100644 --- a/packer/rpc/environment.go +++ b/packer/rpc/environment.go @@ -2,6 +2,7 @@ package rpc import ( "github.com/mitchellh/packer/packer" + "log" "net/rpc" ) @@ -9,12 +10,14 @@ import ( // where the actual environment is executed over an RPC connection. type Environment struct { client *rpc.Client + mux *MuxConn } // A EnvironmentServer wraps a packer.Environment and makes it exportable // as part of a Golang RPC server. type EnvironmentServer struct { env packer.Environment + mux *MuxConn } type EnvironmentCliArgs struct { @@ -22,33 +25,32 @@ type EnvironmentCliArgs struct { } func (e *Environment) Builder(name string) (b packer.Builder, err error) { - var reply string - err = e.client.Call("Environment.Builder", name, &reply) + var streamId uint32 + err = e.client.Call("Environment.Builder", name, &streamId) if err != nil { return } - client, err := rpcDial(reply) + client, err := NewClientWithMux(e.mux, streamId) if err != nil { - return + return nil, err } - - b = Builder(client) + b = client.Builder() return } func (e *Environment) Cache() packer.Cache { - var reply string - if err := e.client.Call("Environment.Cache", new(interface{}), &reply); err != nil { + var streamId uint32 + if err := e.client.Call("Environment.Cache", new(interface{}), &streamId); err != nil { panic(err) } - client, err := rpcDial(reply) + client, err := NewClientWithMux(e.mux, streamId) if err != nil { - panic(err) + log.Printf("[ERR] Error getting cache client: %s", err) + return nil } - - return Cache(client) + return client.Cache() } func (e *Environment) Cli(args []string) (result int, err error) { @@ -58,85 +60,81 @@ func (e *Environment) Cli(args []string) (result int, err error) { } func (e *Environment) Hook(name string) (h packer.Hook, err error) { - var reply string - err = e.client.Call("Environment.Hook", name, &reply) + var streamId uint32 + err = e.client.Call("Environment.Hook", name, &streamId) if err != nil { return } - client, err := rpcDial(reply) + client, err := NewClientWithMux(e.mux, streamId) if err != nil { - return + return nil, err } - - h = Hook(client) - return + return client.Hook(), nil } func (e *Environment) PostProcessor(name string) (p packer.PostProcessor, err error) { - var reply string - err = e.client.Call("Environment.PostProcessor", name, &reply) + var streamId uint32 + err = e.client.Call("Environment.PostProcessor", name, &streamId) if err != nil { return } - client, err := rpcDial(reply) + client, err := NewClientWithMux(e.mux, streamId) if err != nil { - return + return nil, err } - - p = PostProcessor(client) + p = client.PostProcessor() return } func (e *Environment) Provisioner(name string) (p packer.Provisioner, err error) { - var reply string - err = e.client.Call("Environment.Provisioner", name, &reply) + var streamId uint32 + err = e.client.Call("Environment.Provisioner", name, &streamId) if err != nil { return } - client, err := rpcDial(reply) + client, err := NewClientWithMux(e.mux, streamId) if err != nil { - return + return nil, err } - - p = Provisioner(client) + p = client.Provisioner() return } func (e *Environment) Ui() packer.Ui { - var reply string - e.client.Call("Environment.Ui", new(interface{}), &reply) + var streamId uint32 + e.client.Call("Environment.Ui", new(interface{}), &streamId) - client, err := rpcDial(reply) + client, err := NewClientWithMux(e.mux, streamId) if err != nil { - panic(err) + log.Printf("[ERR] Error connecting to Ui: %s", err) + return nil } - - return &Ui{client} + return client.Ui() } -func (e *EnvironmentServer) Builder(name *string, reply *string) error { - builder, err := e.env.Builder(*name) +func (e *EnvironmentServer) Builder(name string, reply *uint32) error { + builder, err := e.env.Builder(name) if err != nil { - return err + return NewBasicError(err) } - // Wrap it - server := rpc.NewServer() - RegisterBuilder(server, builder) - - *reply = serveSingleConn(server) + *reply = e.mux.NextId() + server := NewServerWithMux(e.mux, *reply) + server.RegisterBuilder(builder) + go server.Serve() return nil } -func (e *EnvironmentServer) Cache(args *interface{}, reply *string) error { +func (e *EnvironmentServer) Cache(args *interface{}, reply *uint32) error { cache := e.env.Cache() - server := rpc.NewServer() - RegisterCache(server, cache) - *reply = serveSingleConn(server) + *reply = e.mux.NextId() + server := NewServerWithMux(e.mux, *reply) + server.RegisterCache(cache) + go server.Serve() return nil } @@ -145,53 +143,51 @@ func (e *EnvironmentServer) Cli(args *EnvironmentCliArgs, reply *int) (err error return } -func (e *EnvironmentServer) Hook(name *string, reply *string) error { - hook, err := e.env.Hook(*name) +func (e *EnvironmentServer) Hook(name string, reply *uint32) error { + hook, err := e.env.Hook(name) if err != nil { - return err + return NewBasicError(err) } - // Wrap it - server := rpc.NewServer() - RegisterHook(server, hook) - - *reply = serveSingleConn(server) + *reply = e.mux.NextId() + server := NewServerWithMux(e.mux, *reply) + server.RegisterHook(hook) + go server.Serve() return nil } -func (e *EnvironmentServer) PostProcessor(name *string, reply *string) error { - pp, err := e.env.PostProcessor(*name) +func (e *EnvironmentServer) PostProcessor(name string, reply *uint32) error { + pp, err := e.env.PostProcessor(name) if err != nil { - return err + return NewBasicError(err) } - server := rpc.NewServer() - RegisterPostProcessor(server, pp) - - *reply = serveSingleConn(server) + *reply = e.mux.NextId() + server := NewServerWithMux(e.mux, *reply) + server.RegisterPostProcessor(pp) + go server.Serve() return nil } -func (e *EnvironmentServer) Provisioner(name *string, reply *string) error { - prov, err := e.env.Provisioner(*name) +func (e *EnvironmentServer) Provisioner(name string, reply *uint32) error { + prov, err := e.env.Provisioner(name) if err != nil { - return err + return NewBasicError(err) } - server := rpc.NewServer() - RegisterProvisioner(server, prov) - - *reply = serveSingleConn(server) + *reply = e.mux.NextId() + server := NewServerWithMux(e.mux, *reply) + server.RegisterProvisioner(prov) + go server.Serve() return nil } -func (e *EnvironmentServer) Ui(args *interface{}, reply *string) error { +func (e *EnvironmentServer) Ui(args *interface{}, reply *uint32) error { ui := e.env.Ui() - // Wrap it - server := rpc.NewServer() - RegisterUi(server, ui) - - *reply = serveSingleConn(server) + *reply = e.mux.NextId() + server := NewServerWithMux(e.mux, *reply) + server.RegisterUi(ui) + go server.Serve() return nil } diff --git a/packer/rpc/environment_test.go b/packer/rpc/environment_test.go index cb80929ec..bd0a6784a 100644 --- a/packer/rpc/environment_test.go +++ b/packer/rpc/environment_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) @@ -69,16 +68,11 @@ func TestEnvironmentRPC(t *testing.T) { e := &testEnvironment{} // Start the server - server := rpc.NewServer() - RegisterEnvironment(server, e) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - eClient := &Environment{client} + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterEnvironment(e) + eClient := client.Environment() // Test Builder builder, _ := eClient.Builder("foo") diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 223b96df2..18682fac6 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -10,32 +10,36 @@ import ( // over an RPC connection. type hook struct { client *rpc.Client + mux *MuxConn } // HookServer wraps a packer.Hook implementation and makes it exportable // as part of a Golang RPC server. type HookServer struct { hook packer.Hook + mux *MuxConn } type HookRunArgs struct { - Name string - Data interface{} - RPCAddress string -} - -func Hook(client *rpc.Client) *hook { - return &hook{client} + Name string + Data interface{} + StreamId uint32 } func (h *hook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { - server := rpc.NewServer() - RegisterCommunicator(server, comm) - RegisterUi(server, ui) - address := serveSingleConn(server) + nextId := h.mux.NextId() + server := NewServerWithMux(h.mux, nextId) + server.RegisterCommunicator(comm) + server.RegisterUi(ui) + go server.Serve() - args := &HookRunArgs{name, data, address} - return h.client.Call("Hook.Run", args, new(interface{})) + args := HookRunArgs{ + Name: name, + Data: data, + StreamId: nextId, + } + + return h.client.Call("Hook.Run", &args, new(interface{})) } func (h *hook) Cancel() { @@ -46,12 +50,13 @@ func (h *hook) Cancel() { } func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { - client, err := rpcDial(args.RPCAddress) + client, err := NewClientWithMux(h.mux, args.StreamId) if err != nil { - return err + return NewBasicError(err) } + defer client.Close() - if err := h.hook.Run(args.Name, &Ui{client}, Communicator(client), args.Data); err != nil { + if err := h.hook.Run(args.Name, client.Ui(), client.Communicator(), args.Data); err != nil { return NewBasicError(err) } diff --git a/packer/rpc/hook_test.go b/packer/rpc/hook_test.go index c7ffc7258..b3f4a420c 100644 --- a/packer/rpc/hook_test.go +++ b/packer/rpc/hook_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "sync" "testing" @@ -14,17 +13,11 @@ func TestHookRPC(t *testing.T) { h := new(packer.MockHook) // Serve - server := rpc.NewServer() - RegisterHook(server, h) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - - hClient := Hook(client) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterHook(h) + hClient := client.Hook() // Test Run ui := &testUi{} @@ -60,17 +53,11 @@ func TestHook_cancelWhileRun(t *testing.T) { } // Serve - server := rpc.NewServer() - RegisterHook(server, h) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - - hClient := Hook(client) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterHook(h) + hClient := client.Hook() // Start the run finished := make(chan struct{}) diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go new file mode 100644 index 000000000..9201dfaa7 --- /dev/null +++ b/packer/rpc/muxconn.go @@ -0,0 +1,447 @@ +package rpc + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "time" +) + +// MuxConn is a connection that can be used bi-directionally for RPC. Normally, +// Go RPC only allows client-to-server connections. This allows the client +// to actually act as a server as well. +// +// MuxConn works using a fairly dumb multiplexing technique of simply +// framing every piece of data sent into a prefix + data format. Streams +// are established using a subset of the TCP protocol. Only a subset is +// necessary since we assume ordering on the underlying RWC. +type MuxConn struct { + curId uint32 + rwc io.ReadWriteCloser + streams map[uint32]*Stream + mu sync.RWMutex + wlock sync.Mutex +} + +type muxPacketType byte + +const ( + muxPacketSyn muxPacketType = iota + muxPacketAck + muxPacketFin + muxPacketData +) + +func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { + m := &MuxConn{ + rwc: rwc, + streams: make(map[uint32]*Stream), + } + + go m.loop() + + return m +} + +// Close closes the underlying io.ReadWriteCloser. This will also close +// all streams that are open. +func (m *MuxConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Close all the streams + for _, w := range m.streams { + w.Close() + } + m.streams = make(map[uint32]*Stream) + + return m.rwc.Close() +} + +// Accept accepts a multiplexed connection with the given ID. This +// will block until a request is made to connect. +func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { + stream, err := m.openStream(id) + if err != nil { + return nil, err + } + + // If the stream isn't closed, then it is already open somehow + stream.mu.Lock() + if stream.state != streamStateSynRecv && stream.state != streamStateClosed { + stream.mu.Unlock() + return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) + } + + if stream.state == streamStateSynRecv { + // Fast track establishing since we already got the syn + stream.setState(streamStateEstablished) + stream.mu.Unlock() + } + + if stream.state != streamStateEstablished { + // Go into the listening state + stream.setState(streamStateListen) + + // Register a state change listener to wait for changes + stateCh := make(chan streamState, 10) + stream.registerStateListener(stateCh) + defer func() { + stream.mu.Lock() + defer stream.mu.Unlock() + stream.deregisterStateListener(stateCh) + }() + + stream.mu.Unlock() + + // Wait for the connection to establish + ACCEPT_ESTABLISH_LOOP: + for { + state := <-stateCh + switch state { + case streamStateListen: + case streamStateEstablished: + break ACCEPT_ESTABLISH_LOOP + default: + defer stream.mu.Unlock() + return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state) + } + } + } + + // Send the ack down + if _, err := m.write(stream.id, muxPacketAck, nil); err != nil { + return nil, err + } + + return stream, nil +} + +// Dial opens a connection to the remote end using the given stream ID. +// An Accept on the remote end will only work with if the IDs match. +func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { + stream, err := m.openStream(id) + if err != nil { + return nil, err + } + + // If the stream isn't closed, then it is already open somehow + stream.mu.Lock() + if stream.state != streamStateClosed { + stream.mu.Unlock() + return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state) + } + + // Open a connection + if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil { + return nil, err + } + stream.setState(streamStateSynSent) + + // Register a state change listener to wait for changes + stateCh := make(chan streamState, 10) + stream.registerStateListener(stateCh) + defer func() { + stream.mu.Lock() + defer stream.mu.Unlock() + stream.deregisterStateListener(stateCh) + }() + + stream.mu.Unlock() + + for { + state := <-stateCh + switch state { + case streamStateSynSent: + case streamStateEstablished: + return stream, nil + default: + defer stream.mu.Unlock() + return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state) + } + } +} + +// NextId returns the next available stream ID that isn't currently +// taken. +func (m *MuxConn) NextId() uint32 { + m.mu.Lock() + defer m.mu.Unlock() + + for { + result := m.curId + m.curId++ + if _, ok := m.streams[result]; !ok { + return result + } + } +} + +func (m *MuxConn) openStream(id uint32) (*Stream, error) { + // First grab a read-lock if we have the stream already we can + // cheaply return it. + m.mu.RLock() + if stream, ok := m.streams[id]; ok { + m.mu.RUnlock() + return stream, nil + } + + // Now acquire a full blown write lock so we can create the stream + m.mu.RUnlock() + m.mu.Lock() + defer m.mu.Unlock() + + // We have to check this again because there is a time period + // above where we couldn't lost this lock. + if stream, ok := m.streams[id]; ok { + return stream, nil + } + + // Create the stream object and channel where data will be sent to + dataR, dataW := io.Pipe() + writeCh := make(chan []byte, 256) + + // Set the data channel so we can write to it. + stream := &Stream{ + id: id, + mux: m, + reader: dataR, + writeCh: writeCh, + stateChange: make(map[chan<- streamState]struct{}), + } + stream.setState(streamStateClosed) + + // Start the goroutine that will read from the queue and write + // data out. + go func() { + defer dataW.Close() + + for { + data := <-writeCh + if data == nil { + // A nil is a tombstone letting us know we're done + // accepting data. + return + } + + if _, err := dataW.Write(data); err != nil { + return + } + } + }() + + m.streams[id] = stream + return m.streams[id], nil +} + +func (m *MuxConn) loop() { + defer func() { + m.mu.Lock() + defer m.mu.Unlock() + for _, w := range m.streams { + w.mu.Lock() + w.remoteClose() + w.mu.Unlock() + } + }() + + var id uint32 + var packetType muxPacketType + var length int32 + for { + if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { + log.Printf("[ERR] Error reading stream ID: %s", err) + return + } + if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil { + log.Printf("[ERR] Error reading packet type: %s", err) + return + } + if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil { + log.Printf("[ERR] Error reading length: %s", err) + return + } + + // TODO(mitchellh): probably would be better to re-use a buffer... + data := make([]byte, length) + if length > 0 { + if _, err := m.rwc.Read(data); err != nil { + log.Printf("[ERR] Error reading data: %s", err) + return + } + } + + stream, err := m.openStream(id) + if err != nil { + log.Printf("[ERR] Error opening stream %d: %s", id, err) + return + } + + //log.Printf("[DEBUG] Stream %d received packet %d", id, packetType) + switch packetType { + case muxPacketAck: + stream.mu.Lock() + switch stream.state { + case streamStateSynSent: + stream.setState(streamStateEstablished) + case streamStateFinWait1: + stream.setState(streamStateFinWait2) + default: + log.Printf("[ERR] Ack received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketSyn: + stream.mu.Lock() + switch stream.state { + case streamStateClosed: + stream.setState(streamStateSynRecv) + case streamStateListen: + stream.setState(streamStateEstablished) + default: + log.Printf("[ERR] Syn received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketFin: + stream.mu.Lock() + switch stream.state { + case streamStateEstablished: + stream.setState(streamStateCloseWait) + m.write(id, muxPacketAck, nil) + + // Close the writer on our end since we won't receive any + // more data. + stream.writeCh <- nil + case streamStateFinWait1: + fallthrough + case streamStateFinWait2: + stream.remoteClose() + + // Remove this stream from being active so that it + // can be re-used + m.mu.Lock() + delete(m.streams, stream.id) + m.mu.Unlock() + default: + log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) + } + stream.mu.Unlock() + + case muxPacketData: + stream.mu.Lock() + if stream.state == streamStateEstablished { + select { + case stream.writeCh <- data: + default: + panic(fmt.Sprintf("Failed to write data, buffer full for stream %d", id)) + } + } else { + log.Printf("[ERR] Data received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + } + } +} + +func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) { + m.wlock.Lock() + defer m.wlock.Unlock() + + if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { + return 0, err + } + if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil { + return 0, err + } + if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + return m.rwc.Write(p) +} + +// Stream is a single stream of data and implements io.ReadWriteCloser +type Stream struct { + id uint32 + mux *MuxConn + reader io.Reader + state streamState + stateChange map[chan<- streamState]struct{} + stateUpdated time.Time + mu sync.Mutex + writeCh chan<- []byte +} + +type streamState byte + +const ( + streamStateClosed streamState = iota + streamStateListen + streamStateSynRecv + streamStateSynSent + streamStateEstablished + streamStateFinWait1 + streamStateFinWait2 + streamStateCloseWait +) + +func (s *Stream) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.state != streamStateEstablished && s.state != streamStateCloseWait { + return fmt.Errorf("Stream in bad state: %d", s.state) + } + + if s.state == streamStateEstablished { + s.setState(streamStateFinWait1) + } else { + s.remoteClose() + } + + s.mux.write(s.id, muxPacketFin, nil) + return nil +} + +func (s *Stream) Read(p []byte) (int, error) { + return s.reader.Read(p) +} + +func (s *Stream) Write(p []byte) (int, error) { + s.mu.Lock() + state := s.state + s.mu.Unlock() + + if state != streamStateEstablished { + return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) + } + + return s.mux.write(s.id, muxPacketData, p) +} + +func (s *Stream) remoteClose() { + s.setState(streamStateClosed) + s.writeCh <- nil +} + +func (s *Stream) registerStateListener(ch chan<- streamState) { + s.stateChange[ch] = struct{}{} +} + +func (s *Stream) deregisterStateListener(ch chan<- streamState) { + delete(s.stateChange, ch) +} + +func (s *Stream) setState(state streamState) { + s.state = state + s.stateUpdated = time.Now().UTC() + for ch, _ := range s.stateChange { + select { + case ch <- state: + default: + } + } +} diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go new file mode 100644 index 000000000..f4ba59db1 --- /dev/null +++ b/packer/rpc/muxconn_test.go @@ -0,0 +1,202 @@ +package rpc + +import ( + "io" + "net" + "sync" + "testing" +) + +func readStream(t *testing.T, s io.Reader) string { + var data [1024]byte + n, err := s.Read(data[:]) + if err != nil { + t.Fatalf("err: %s", err) + } + + return string(data[0:n]) +} + +func testMux(t *testing.T) (client *MuxConn, server *MuxConn) { + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Server side + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + conn, err := l.Accept() + l.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + server = NewMuxConn(conn) + }() + + // Client side + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + client = NewMuxConn(conn) + + // Wait for the server + <-doneCh + + return +} + +func TestMuxConn(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + + // When the server is done + doneCh := make(chan struct{}) + + // The server side + go func() { + defer close(doneCh) + + s0, err := server.Accept(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + s1, err := server.Dial(1) + if err != nil { + t.Fatalf("err: %s", err) + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + data := readStream(t, s1) + if data != "another" { + t.Fatalf("bad: %#v", data) + } + }() + + go func() { + defer wg.Done() + data := readStream(t, s0) + if data != "hello" { + t.Fatalf("bad: %#v", data) + } + }() + + wg.Wait() + }() + + s0, err := client.Dial(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + s1, err := client.Accept(1) + if err != nil { + t.Fatalf("err: %s", err) + } + + if _, err := s0.Write([]byte("hello")); err != nil { + t.Fatalf("err: %s", err) + } + if _, err := s1.Write([]byte("another")); err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server to be done + <-doneCh +} + +func TestMuxConn_socketClose(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + + go func() { + _, err := server.Accept(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + server.rwc.Close() + }() + + s0, err := client.Dial(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + var data [1024]byte + _, err = s0.Read(data[:]) + if err != io.EOF { + t.Fatalf("err: %s", err) + } +} + +func TestMuxConn_clientClosesStreams(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + + go func() { + conn, err := server.Accept(0) + if err != nil { + t.Fatalf("err: %s", err) + } + conn.Close() + }() + + s0, err := client.Dial(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + var data [1024]byte + _, err = s0.Read(data[:]) + if err != io.EOF { + t.Fatalf("err: %s", err) + } +} + +func TestMuxConn_serverClosesStreams(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + go server.Accept(0) + + s0, err := client.Dial(0) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := server.Close(); err != nil { + t.Fatalf("err: %s", err) + } + + // This should block forever since we never write onto this stream. + var data [1024]byte + _, err = s0.Read(data[:]) + if err != io.EOF { + t.Fatalf("err: %s", err) + } +} + +func TestMuxConnNextId(t *testing.T) { + client, server := testMux(t) + defer client.Close() + defer server.Close() + + a := client.NextId() + b := client.NextId() + + if a != 0 || b != 1 { + t.Fatalf("IDs should increment") + } +} diff --git a/packer/rpc/post_processor.go b/packer/rpc/post_processor.go index 0a5eaefd3..ad8fd0b9d 100644 --- a/packer/rpc/post_processor.go +++ b/packer/rpc/post_processor.go @@ -9,12 +9,15 @@ import ( // executed over an RPC connection. type postProcessor struct { client *rpc.Client + mux *MuxConn } // PostProcessorServer wraps a packer.PostProcessor implementation and makes it // exportable as part of a Golang RPC server. type PostProcessorServer struct { - p packer.PostProcessor + client *rpc.Client + mux *MuxConn + p packer.PostProcessor } type PostProcessorConfigureArgs struct { @@ -22,14 +25,11 @@ type PostProcessorConfigureArgs struct { } type PostProcessorProcessResponse struct { - Err error - Keep bool - RPCAddress string + Err error + Keep bool + StreamId uint32 } -func PostProcessor(client *rpc.Client) *postProcessor { - return &postProcessor{client} -} func (p *postProcessor) Configure(raw ...interface{}) (err error) { args := &PostProcessorConfigureArgs{Configs: raw} if cerr := p.client.Call("PostProcessor.Configure", args, &err); cerr != nil { @@ -40,12 +40,14 @@ func (p *postProcessor) Configure(raw ...interface{}) (err error) { } func (p *postProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Artifact, bool, error) { - server := rpc.NewServer() - RegisterArtifact(server, a) - RegisterUi(server, ui) + nextId := p.mux.NextId() + server := NewServerWithMux(p.mux, nextId) + server.RegisterArtifact(a) + server.RegisterUi(ui) + go server.Serve() var response PostProcessorProcessResponse - if err := p.client.Call("PostProcessor.PostProcess", serveSingleConn(server), &response); err != nil { + if err := p.client.Call("PostProcessor.PostProcess", nextId, &response); err != nil { return nil, false, err } @@ -53,16 +55,16 @@ func (p *postProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Art return nil, false, response.Err } - if response.RPCAddress == "" { + if response.StreamId == 0 { return nil, false, nil } - client, err := rpcDial(response.RPCAddress) + client, err := NewClientWithMux(p.mux, response.StreamId) if err != nil { return nil, false, err } - return Artifact(client), response.Keep, nil + return client.Artifact(), response.Keep, nil } func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply *error) error { @@ -74,19 +76,20 @@ func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply return nil } -func (p *PostProcessorServer) PostProcess(address string, reply *PostProcessorProcessResponse) error { - client, err := rpcDial(address) +func (p *PostProcessorServer) PostProcess(streamId uint32, reply *PostProcessorProcessResponse) error { + client, err := NewClientWithMux(p.mux, streamId) if err != nil { - return err + return NewBasicError(err) } + defer client.Close() - responseAddress := "" - - artifact, keep, err := p.p.PostProcess(&Ui{client}, Artifact(client)) - if err == nil && artifact != nil { - server := rpc.NewServer() - RegisterArtifact(server, artifact) - responseAddress = serveSingleConn(server) + streamId = 0 + artifactResult, keep, err := p.p.PostProcess(client.Ui(), client.Artifact()) + if err == nil && artifactResult != nil { + streamId = p.mux.NextId() + server := NewServerWithMux(p.mux, streamId) + server.RegisterArtifact(artifactResult) + go server.Serve() } if err != nil { @@ -94,9 +97,9 @@ func (p *PostProcessorServer) PostProcess(address string, reply *PostProcessorPr } *reply = PostProcessorProcessResponse{ - Err: err, - Keep: keep, - RPCAddress: responseAddress, + Err: err, + Keep: keep, + StreamId: streamId, } return nil diff --git a/packer/rpc/post_processor_test.go b/packer/rpc/post_processor_test.go index 75f3deca0..dabd8c03a 100644 --- a/packer/rpc/post_processor_test.go +++ b/packer/rpc/post_processor_test.go @@ -2,18 +2,18 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) -var testPostProcessorArtifact = new(testArtifact) +var testPostProcessorArtifact = new(packer.MockArtifact) type TestPostProcessor struct { configCalled bool configVal []interface{} ppCalled bool ppArtifact packer.Artifact + ppArtifactId string ppUi packer.Ui } @@ -26,6 +26,7 @@ func (pp *TestPostProcessor) Configure(v ...interface{}) error { func (pp *TestPostProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Artifact, bool, error) { pp.ppCalled = true pp.ppArtifact = a + pp.ppArtifactId = a.Id() pp.ppUi = ui return testPostProcessorArtifact, false, nil } @@ -35,20 +36,16 @@ func TestPostProcessorRPC(t *testing.T) { p := new(TestPostProcessor) // Start the server - server := rpc.NewServer() - RegisterPostProcessor(server, p) - address := serveSingleConn(server) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterPostProcessor(p) - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("Error connecting to rpc: %s", err) - } + ppClient := client.PostProcessor() // Test Configure config := 42 - pClient := PostProcessor(client) - err = pClient.Configure(config) + err := ppClient.Configure(config) if err != nil { t.Fatalf("error: %s", err) } @@ -62,9 +59,11 @@ func TestPostProcessorRPC(t *testing.T) { } // Test PostProcess - a := new(testArtifact) + a := &packer.MockArtifact{ + IdValue: "ppTestId", + } ui := new(testUi) - artifact, _, err := pClient.PostProcess(ui, a) + artifact, _, err := ppClient.PostProcess(ui, a) if err != nil { t.Fatalf("err: %s", err) } @@ -73,18 +72,18 @@ func TestPostProcessorRPC(t *testing.T) { t.Fatal("postprocess should be called") } - if p.ppArtifact.BuilderId() != "bid" { - t.Fatal("unknown artifact") + if p.ppArtifactId != "ppTestId" { + t.Fatalf("unknown artifact: %s", p.ppArtifact.Id()) } - if artifact.BuilderId() != "bid" { - t.Fatal("unknown result artifact") + if artifact.Id() != "id" { + t.Fatalf("unknown artifact: %s", artifact.Id()) } } func TestPostProcessor_Implements(t *testing.T) { var raw interface{} - raw = PostProcessor(nil) + raw = new(postProcessor) if _, ok := raw.(packer.PostProcessor); !ok { t.Fatal("not a postprocessor") } diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index 4cd329d6b..7ddb32ec3 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -10,25 +10,20 @@ import ( // executed over an RPC connection. type provisioner struct { client *rpc.Client + mux *MuxConn } // ProvisionerServer wraps a packer.Provisioner implementation and makes it // exportable as part of a Golang RPC server. type ProvisionerServer struct { - p packer.Provisioner + p packer.Provisioner + mux *MuxConn } type ProvisionerPrepareArgs struct { Configs []interface{} } -type ProvisionerProvisionArgs struct { - RPCAddress string -} - -func Provisioner(client *rpc.Client) *provisioner { - return &provisioner{client} -} func (p *provisioner) Prepare(configs ...interface{}) (err error) { args := &ProvisionerPrepareArgs{configs} if cerr := p.client.Call("Provisioner.Prepare", args, &err); cerr != nil { @@ -39,13 +34,13 @@ func (p *provisioner) Prepare(configs ...interface{}) (err error) { } func (p *provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { - // TODO: Error handling - server := rpc.NewServer() - RegisterCommunicator(server, comm) - RegisterUi(server, ui) + nextId := p.mux.NextId() + server := NewServerWithMux(p.mux, nextId) + server.RegisterCommunicator(comm) + server.RegisterUi(ui) + go server.Serve() - args := &ProvisionerProvisionArgs{serveSingleConn(server)} - return p.client.Call("Provisioner.Provision", args, new(interface{})) + return p.client.Call("Provisioner.Provision", nextId, new(interface{})) } func (p *provisioner) Cancel() { @@ -64,16 +59,14 @@ func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *error) return nil } -func (p *ProvisionerServer) Provision(args *ProvisionerProvisionArgs, reply *interface{}) error { - client, err := rpcDial(args.RPCAddress) +func (p *ProvisionerServer) Provision(streamId uint32, reply *interface{}) error { + client, err := NewClientWithMux(p.mux, streamId) if err != nil { - return err + return NewBasicError(err) } + defer client.Close() - comm := Communicator(client) - ui := &Ui{client} - - if err := p.p.Provision(ui, comm); err != nil { + if err := p.p.Provision(client.Ui(), client.Communicator()); err != nil { return NewBasicError(err) } diff --git a/packer/rpc/provisioner_test.go b/packer/rpc/provisioner_test.go index 7e281a1b5..a17e0694c 100644 --- a/packer/rpc/provisioner_test.go +++ b/packer/rpc/provisioner_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) @@ -12,19 +11,14 @@ func TestProvisionerRPC(t *testing.T) { p := new(packer.MockProvisioner) // Start the server - server := rpc.NewServer() - RegisterProvisioner(server, p) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterProvisioner(p) + pClient := client.Provisioner() // Test Prepare config := 42 - pClient := Provisioner(client) pClient.Prepare(config) if !p.PrepCalled { t.Fatal("should be called") @@ -41,11 +35,6 @@ func TestProvisionerRPC(t *testing.T) { t.Fatal("should be called") } - p.ProvUi.Say("foo") - if !ui.sayCalled { - t.Fatal("should be called") - } - // Test Cancel pClient.Cancel() if !p.CancelCalled { @@ -54,5 +43,5 @@ func TestProvisionerRPC(t *testing.T) { } func TestProvisioner_Implements(t *testing.T) { - var _ packer.Provisioner = Provisioner(nil) + var _ packer.Provisioner = new(provisioner) } diff --git a/packer/rpc/server.go b/packer/rpc/server.go index f6098ea7a..bf194bd6c 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -1,88 +1,155 @@ package rpc import ( + "fmt" "github.com/mitchellh/packer/packer" + "io" + "log" "net/rpc" + "sync/atomic" ) -// Registers the appropriate endpoint on an RPC server to serve an -// Artifact. -func RegisterArtifact(s *rpc.Server, a packer.Artifact) { - s.RegisterName("Artifact", &ArtifactServer{a}) +var endpointId uint64 + +const ( + DefaultArtifactEndpoint string = "Artifact" + DefaultBuildEndpoint = "Build" + DefaultBuilderEndpoint = "Builder" + DefaultCacheEndpoint = "Cache" + DefaultCommandEndpoint = "Command" + DefaultCommunicatorEndpoint = "Communicator" + DefaultEnvironmentEndpoint = "Environment" + DefaultHookEndpoint = "Hook" + DefaultPostProcessorEndpoint = "PostProcessor" + DefaultProvisionerEndpoint = "Provisioner" + DefaultUiEndpoint = "Ui" +) + +// Server represents an RPC server for Packer. This must be paired on +// the other side with a Client. +type Server struct { + mux *MuxConn + streamId uint32 + server *rpc.Server } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer Build. -func RegisterBuild(s *rpc.Server, b packer.Build) { - s.RegisterName("Build", &BuildServer{b}) +// NewServer returns a new Packer RPC server. +func NewServer(conn io.ReadWriteCloser) *Server { + return NewServerWithMux(NewMuxConn(conn), 0) } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer Builder. -func RegisterBuilder(s *rpc.Server, b packer.Builder) { - s.RegisterName("Builder", &BuilderServer{b}) +func NewServerWithMux(mux *MuxConn, streamId uint32) *Server { + return &Server{ + mux: mux, + streamId: streamId, + server: rpc.NewServer(), + } } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer Cache. -func RegisterCache(s *rpc.Server, c packer.Cache) { - s.RegisterName("Cache", &CacheServer{c}) +func (s *Server) Close() error { + return s.mux.Close() } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer Command. -func RegisterCommand(s *rpc.Server, c packer.Command) { - s.RegisterName("Command", &CommandServer{c}) +func (s *Server) RegisterArtifact(a packer.Artifact) { + s.server.RegisterName(DefaultArtifactEndpoint, &ArtifactServer{ + artifact: a, + }) } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer Communicator. -func RegisterCommunicator(s *rpc.Server, c packer.Communicator) { - s.RegisterName("Communicator", &CommunicatorServer{c}) +func (s *Server) RegisterBuild(b packer.Build) { + s.server.RegisterName(DefaultBuildEndpoint, &BuildServer{ + build: b, + mux: s.mux, + }) } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer Environment -func RegisterEnvironment(s *rpc.Server, e packer.Environment) { - s.RegisterName("Environment", &EnvironmentServer{e}) +func (s *Server) RegisterBuilder(b packer.Builder) { + s.server.RegisterName(DefaultBuilderEndpoint, &BuilderServer{ + builder: b, + mux: s.mux, + }) } -// Registers the appropriate endpoint on an RPC server to serve a -// Hook. -func RegisterHook(s *rpc.Server, hook packer.Hook) { - s.RegisterName("Hook", &HookServer{hook}) +func (s *Server) RegisterCache(c packer.Cache) { + s.server.RegisterName(DefaultCacheEndpoint, &CacheServer{ + cache: c, + }) } -// Registers the appropriate endpoing on an RPC server to serve a -// PostProcessor. -func RegisterPostProcessor(s *rpc.Server, p packer.PostProcessor) { - s.RegisterName("PostProcessor", &PostProcessorServer{p}) +func (s *Server) RegisterCommand(c packer.Command) { + s.server.RegisterName(DefaultCommandEndpoint, &CommandServer{ + command: c, + mux: s.mux, + }) } -// Registers the appropriate endpoint on an RPC server to serve a packer.Provisioner -func RegisterProvisioner(s *rpc.Server, p packer.Provisioner) { - s.RegisterName("Provisioner", &ProvisionerServer{p}) +func (s *Server) RegisterCommunicator(c packer.Communicator) { + s.server.RegisterName(DefaultCommunicatorEndpoint, &CommunicatorServer{ + c: c, + mux: s.mux, + }) } -// Registers the appropriate endpoint on an RPC server to serve a -// Packer UI -func RegisterUi(s *rpc.Server, ui packer.Ui) { - s.RegisterName("Ui", &UiServer{ui}) +func (s *Server) RegisterEnvironment(b packer.Environment) { + s.server.RegisterName(DefaultEnvironmentEndpoint, &EnvironmentServer{ + env: b, + mux: s.mux, + }) } -func serveSingleConn(s *rpc.Server) string { - l := netListenerInRange(portRangeMin, portRangeMax) - - // Accept a single connection in a goroutine and then exit - go func() { - defer l.Close() - conn, err := l.Accept() - if err != nil { - panic(err) - } - - s.ServeConn(conn) - }() - - return l.Addr().String() +func (s *Server) RegisterHook(h packer.Hook) { + s.server.RegisterName(DefaultHookEndpoint, &HookServer{ + hook: h, + mux: s.mux, + }) +} + +func (s *Server) RegisterPostProcessor(p packer.PostProcessor) { + s.server.RegisterName(DefaultPostProcessorEndpoint, &PostProcessorServer{ + mux: s.mux, + p: p, + }) +} + +func (s *Server) RegisterProvisioner(p packer.Provisioner) { + s.server.RegisterName(DefaultProvisionerEndpoint, &ProvisionerServer{ + mux: s.mux, + p: p, + }) +} + +func (s *Server) RegisterUi(ui packer.Ui) { + s.server.RegisterName(DefaultUiEndpoint, &UiServer{ + ui: ui, + }) +} + +// ServeConn serves a single connection over the RPC server. It is up +// to the caller to obtain a proper io.ReadWriteCloser. +func (s *Server) Serve() { + // Accept a connection on stream ID 0, which is always used for + // normal client to server connections. + stream, err := s.mux.Accept(s.streamId) + defer stream.Close() + if err != nil { + log.Printf("[ERR] Error retrieving stream for serving: %s", err) + return + } + + s.server.ServeConn(stream) +} + +// registerComponent registers a single Packer RPC component onto +// the RPC server. If id is true, then a unique ID number will be appended +// onto the end of the endpoint. +// +// The endpoint name is returned. +func registerComponent(server *rpc.Server, name string, rcvr interface{}, id bool) string { + endpoint := name + if id { + fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&endpointId, 1)) + } + + server.RegisterName(endpoint, rcvr) + return endpoint } diff --git a/packer/rpc/ui.go b/packer/rpc/ui.go index 4d7ccc57f..da032160a 100644 --- a/packer/rpc/ui.go +++ b/packer/rpc/ui.go @@ -9,7 +9,8 @@ import ( // An implementation of packer.Ui where the Ui is actually executed // over an RPC connection. type Ui struct { - client *rpc.Client + client *rpc.Client + endpoint string } // UiServer wraps a packer.Ui implementation and makes it exportable diff --git a/packer/rpc/ui_test.go b/packer/rpc/ui_test.go index 8a85c5573..4dd2d7036 100644 --- a/packer/rpc/ui_test.go +++ b/packer/rpc/ui_test.go @@ -1,7 +1,6 @@ package rpc import ( - "net/rpc" "reflect" "testing" ) @@ -52,17 +51,12 @@ func TestUiRPC(t *testing.T) { ui := new(testUi) // Start the RPC server - server := rpc.NewServer() - RegisterUi(server, ui) - address := serveSingleConn(server) + client, server := testClientServer(t) + defer client.Close() + defer server.Close() + server.RegisterUi(ui) - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - panic(err) - } - - uiClient := &Ui{client} + uiClient := client.Ui() // Basic error and say tests result, err := uiClient.Ask("query") diff --git a/plugin/builder-amazon-chroot/main.go b/plugin/builder-amazon-chroot/main.go index b7d71df44..13fa295f7 100644 --- a/plugin/builder-amazon-chroot/main.go +++ b/plugin/builder-amazon-chroot/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(chroot.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(chroot.Builder)) + server.Serve() } diff --git a/plugin/builder-amazon-ebs/main.go b/plugin/builder-amazon-ebs/main.go index f0598d749..ba6a6e380 100644 --- a/plugin/builder-amazon-ebs/main.go +++ b/plugin/builder-amazon-ebs/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(ebs.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(ebs.Builder)) + server.Serve() } diff --git a/plugin/builder-amazon-instance/main.go b/plugin/builder-amazon-instance/main.go index fc1c66b40..e87654107 100644 --- a/plugin/builder-amazon-instance/main.go +++ b/plugin/builder-amazon-instance/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(instance.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(instance.Builder)) + server.Serve() } diff --git a/plugin/builder-digitalocean/main.go b/plugin/builder-digitalocean/main.go index be8d706f0..a41ed3d65 100644 --- a/plugin/builder-digitalocean/main.go +++ b/plugin/builder-digitalocean/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(digitalocean.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(digitalocean.Builder)) + server.Serve() } diff --git a/plugin/builder-docker/main.go b/plugin/builder-docker/main.go index a5b69de0b..1ed06dd5e 100644 --- a/plugin/builder-docker/main.go +++ b/plugin/builder-docker/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(docker.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(docker.Builder)) + server.Serve() } diff --git a/plugin/builder-openstack/main.go b/plugin/builder-openstack/main.go index 076011003..7753e04ff 100644 --- a/plugin/builder-openstack/main.go +++ b/plugin/builder-openstack/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(openstack.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(openstack.Builder)) + server.Serve() } diff --git a/plugin/builder-qemu/main.go b/plugin/builder-qemu/main.go index 710742bad..b823ba501 100644 --- a/plugin/builder-qemu/main.go +++ b/plugin/builder-qemu/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(qemu.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(qemu.Builder)) + server.Serve() } diff --git a/plugin/builder-virtualbox/main.go b/plugin/builder-virtualbox/main.go index 1428c5b1d..136b2698e 100644 --- a/plugin/builder-virtualbox/main.go +++ b/plugin/builder-virtualbox/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(virtualbox.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(virtualbox.Builder)) + server.Serve() } diff --git a/plugin/builder-vmware/main.go b/plugin/builder-vmware/main.go index a449751b8..14a10cf5c 100644 --- a/plugin/builder-vmware/main.go +++ b/plugin/builder-vmware/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeBuilder(new(vmware.Builder)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterBuilder(new(vmware.Builder)) + server.Serve() } diff --git a/plugin/command-build/main.go b/plugin/command-build/main.go index e6a4397fd..1285e1e7a 100644 --- a/plugin/command-build/main.go +++ b/plugin/command-build/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeCommand(new(build.Command)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterCommand(new(build.Command)) + server.Serve() } diff --git a/plugin/command-fix/main.go b/plugin/command-fix/main.go index 8ae6d42a0..03677674d 100644 --- a/plugin/command-fix/main.go +++ b/plugin/command-fix/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeCommand(new(fix.Command)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterCommand(new(fix.Command)) + server.Serve() } diff --git a/plugin/command-inspect/main.go b/plugin/command-inspect/main.go index 6d2d7d3d5..9aeedc34c 100644 --- a/plugin/command-inspect/main.go +++ b/plugin/command-inspect/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeCommand(new(inspect.Command)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterCommand(new(inspect.Command)) + server.Serve() } diff --git a/plugin/command-validate/main.go b/plugin/command-validate/main.go index af093294e..1105ed75e 100644 --- a/plugin/command-validate/main.go +++ b/plugin/command-validate/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeCommand(new(validate.Command)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterCommand(new(validate.Command)) + server.Serve() } diff --git a/plugin/post-processor-vagrant/main.go b/plugin/post-processor-vagrant/main.go index 0224d26d1..cf84e3ff1 100644 --- a/plugin/post-processor-vagrant/main.go +++ b/plugin/post-processor-vagrant/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServePostProcessor(new(vagrant.PostProcessor)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterPostProcessor(new(vagrant.PostProcessor)) + server.Serve() } diff --git a/plugin/post-processor-vsphere/main.go b/plugin/post-processor-vsphere/main.go index 22317ab5d..10e7a8d40 100644 --- a/plugin/post-processor-vsphere/main.go +++ b/plugin/post-processor-vsphere/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServePostProcessor(new(vsphere.PostProcessor)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterPostProcessor(new(vsphere.PostProcessor)) + server.Serve() } diff --git a/plugin/provisioner-ansible-local/main.go b/plugin/provisioner-ansible-local/main.go index 0caf0427f..5477b1b55 100644 --- a/plugin/provisioner-ansible-local/main.go +++ b/plugin/provisioner-ansible-local/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeProvisioner(new(ansiblelocal.Provisioner)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(ansiblelocal.Provisioner)) + server.Serve() } diff --git a/plugin/provisioner-chef-solo/main.go b/plugin/provisioner-chef-solo/main.go index 4057f9925..e8a5bf11e 100644 --- a/plugin/provisioner-chef-solo/main.go +++ b/plugin/provisioner-chef-solo/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeProvisioner(new(chefsolo.Provisioner)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(chefsolo.Provisioner)) + server.Serve() } diff --git a/plugin/provisioner-file/main.go b/plugin/provisioner-file/main.go index 1f5a63413..1b746e050 100644 --- a/plugin/provisioner-file/main.go +++ b/plugin/provisioner-file/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeProvisioner(new(file.Provisioner)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(file.Provisioner)) + server.Serve() } diff --git a/plugin/provisioner-puppet-masterless/main.go b/plugin/provisioner-puppet-masterless/main.go index c510cbb78..45dba0a31 100644 --- a/plugin/provisioner-puppet-masterless/main.go +++ b/plugin/provisioner-puppet-masterless/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeProvisioner(new(puppetmasterless.Provisioner)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(puppetmasterless.Provisioner)) + server.Serve() } diff --git a/plugin/provisioner-salt-masterless/main.go b/plugin/provisioner-salt-masterless/main.go index 53ccd771a..584c4657e 100644 --- a/plugin/provisioner-salt-masterless/main.go +++ b/plugin/provisioner-salt-masterless/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeProvisioner(new(saltmasterless.Provisioner)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(saltmasterless.Provisioner)) + server.Serve() } diff --git a/plugin/provisioner-shell/main.go b/plugin/provisioner-shell/main.go index 07d18472e..03966b67b 100644 --- a/plugin/provisioner-shell/main.go +++ b/plugin/provisioner-shell/main.go @@ -6,5 +6,10 @@ import ( ) func main() { - plugin.ServeProvisioner(new(shell.Provisioner)) + server, err := plugin.Server() + if err != nil { + panic(err) + } + server.RegisterProvisioner(new(shell.Provisioner)) + server.Serve() }