diff --git a/common/net/rpc/alias.go b/common/net/rpc/alias.go new file mode 100644 index 000000000..0abe9307e --- /dev/null +++ b/common/net/rpc/alias.go @@ -0,0 +1,37 @@ +package rpc + +import ( + baserpc "net/rpc" + + "github.com/keegancsmith/rpc" +) + +type Server = rpc.Server +type Client = rpc.Client +type Request = rpc.Request +type Response = rpc.Response + +type ClientCodec interface { + WriteRequest(*baserpc.Request, interface{}) error + ReadResponseHeader(*baserpc.Response) error + ReadResponseBody(interface{}) error + + Close() error +} + +type ServerCodec interface { + ReadRequestHeader(*baserpc.Request) error + ReadRequestBody(interface{}) error + WriteResponse(*baserpc.Response, interface{}) error + + // Close can be called multiple times and must be idempotent. + Close() error +} + +func NewClientWithCodec(codec rpc.ClientCodec) *Client { + return rpc.NewClientWithCodec(codec) +} + +func NewServer() *Server { + return rpc.NewServer() +} diff --git a/common/net/rpc/codec/alias.go b/common/net/rpc/codec/alias.go new file mode 100644 index 000000000..77c2693f8 --- /dev/null +++ b/common/net/rpc/codec/alias.go @@ -0,0 +1,51 @@ +package codec + +import ( + "io" + baserpc "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" + "github.com/ugorji/go/codec" +) + +type Handle = codec.Handle + +var msgpackHandle = &codec.MsgpackHandle{ + RawToString: true, + WriteExt: true, +} + +type serverCodec struct { + baserpc.ServerCodec +} + +func (s *serverCodec) ReadRequestHeader(r *rpc.Request) error { + return s.ServerCodec.ReadRequestHeader(r) +} + +func (s *serverCodec) WriteResponse(r *baserpc.Response, v interface{}) error { + return s.ServerCodec.WriteResponse(r, v) +} + +func MsgpackServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { + c := codec.GoRpc.ServerCodec(conn, msgpackHandle) + return &serverCodec{c} +} + +type clientCodec struct { + baserpc.ClientCodec +} + +func (c *clientCodec) WriteRequest(req *baserpc.Request, v interface{}) error { + return c.ClientCodec.WriteRequest(req, v) +} + +func (c *clientCodec) ReadResponseHeader(res *rpc.Response) error { + r := baserpc.Response(*res) + return c.ClientCodec.ReadResponseHeader() +} + +func MsgpackClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { + c := codec.GoRpc.ClientCodec(conn, msgpackHandle) + return &clientCodec{c} +} diff --git a/packer/rpc/artifact.go b/packer/rpc/artifact.go index b1166afa7..538a2157c 100644 --- a/packer/rpc/artifact.go +++ b/packer/rpc/artifact.go @@ -1,7 +1,9 @@ package rpc import ( - "net/rpc" + "context" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -20,33 +22,40 @@ type ArtifactServer struct { } func (a *artifact) BuilderId() (result string) { - a.client.Call(a.endpoint+".BuilderId", new(interface{}), &result) + ctx := context.TODO() + a.client.Call(ctx, a.endpoint+".BuilderId", new(interface{}), &result) return } func (a *artifact) Files() (result []string) { - a.client.Call(a.endpoint+".Files", new(interface{}), &result) + ctx := context.TODO() + a.client.Call(ctx, a.endpoint+".Files", new(interface{}), &result) return } func (a *artifact) Id() (result string) { - a.client.Call(a.endpoint+".Id", new(interface{}), &result) + ctx := context.TODO() + a.client.Call(ctx, a.endpoint+".Id", new(interface{}), &result) return } func (a *artifact) String() (result string) { - a.client.Call(a.endpoint+".String", new(interface{}), &result) + ctx := context.TODO() + a.client.Call(ctx, a.endpoint+".String", new(interface{}), &result) return } func (a *artifact) State(name string) (result interface{}) { - a.client.Call(a.endpoint+".State", name, &result) + ctx := context.TODO() + a.client.Call(ctx, a.endpoint+".State", name, &result) return } func (a *artifact) Destroy() error { + ctx := context.TODO() + var result error - if err := a.client.Call(a.endpoint+".Destroy", new(interface{}), &result); err != nil { + if err := a.client.Call(ctx, a.endpoint+".Destroy", new(interface{}), &result); err != nil { return err } diff --git a/packer/rpc/build.go b/packer/rpc/build.go index 916ea7138..c91074eb4 100644 --- a/packer/rpc/build.go +++ b/packer/rpc/build.go @@ -2,7 +2,8 @@ package rpc import ( "context" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -27,13 +28,15 @@ type BuildPrepareResponse struct { } func (b *build) Name() (result string) { - b.client.Call("Build.Name", new(interface{}), &result) + ctx := context.TODO() + b.client.Call(ctx, "Build.Name", new(interface{}), &result) return } func (b *build) Prepare() ([]string, error) { + ctx := context.TODO() var resp BuildPrepareResponse - if cerr := b.client.Call("Build.Prepare", new(interface{}), &resp); cerr != nil { + if cerr := b.client.Call(ctx, "Build.Prepare", new(interface{}), &resp); cerr != nil { return nil, cerr } var err error = nil @@ -51,7 +54,7 @@ func (b *build) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error go server.Serve() var result []uint32 - if err := b.client.Call("Build.Run", nextId, &result); err != nil { + if err := b.client.Call(ctx, "Build.Run", nextId, &result); err != nil { return nil, err } @@ -69,25 +72,22 @@ func (b *build) Run(ctx context.Context, ui packer.Ui) ([]packer.Artifact, error } func (b *build) SetDebug(val bool) { - if err := b.client.Call("Build.SetDebug", val, new(interface{})); err != nil { + ctx := context.TODO() + if err := b.client.Call(ctx, "Build.SetDebug", val, new(interface{})); err != nil { panic(err) } } func (b *build) SetForce(val bool) { - if err := b.client.Call("Build.SetForce", val, new(interface{})); err != nil { + ctx := context.TODO() + if err := b.client.Call(ctx, "Build.SetForce", val, new(interface{})); err != nil { panic(err) } } func (b *build) SetOnError(val string) { - if err := b.client.Call("Build.SetOnError", val, new(interface{})); err != nil { - panic(err) - } -} - -func (b *build) Cancel() { - if err := b.client.Call("Build.Cancel", new(interface{}), new(interface{})); err != nil { + ctx := context.TODO() + if err := b.client.Call(ctx, "Build.SetOnError", val, new(interface{})); err != nil { panic(err) } } diff --git a/packer/rpc/builder.go b/packer/rpc/builder.go index b8ba27b35..554e88cbd 100644 --- a/packer/rpc/builder.go +++ b/packer/rpc/builder.go @@ -2,7 +2,8 @@ package rpc import ( "context" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -32,7 +33,8 @@ type BuilderPrepareResponse struct { func (b *builder) Prepare(config ...interface{}) ([]string, error) { var resp BuilderPrepareResponse - cerr := b.client.Call("Builder.Prepare", &BuilderPrepareArgs{config}, &resp) + ctx := context.TODO() + cerr := b.client.Call(ctx, "Builder.Prepare", &BuilderPrepareArgs{config}, &resp) if cerr != nil { return nil, cerr } @@ -52,7 +54,7 @@ func (b *builder) Run(ctx context.Context, ui packer.Ui, hook packer.Hook) (pack go server.Serve() var responseId uint32 - if err := b.client.Call("Builder.Run", nextId, &responseId); err != nil { + if err := b.client.Call(ctx, "Builder.Run", nextId, &responseId); err != nil { return nil, err } diff --git a/packer/rpc/client.go b/packer/rpc/client.go index 53f88d52f..732421a23 100644 --- a/packer/rpc/client.go +++ b/packer/rpc/client.go @@ -3,10 +3,12 @@ package rpc import ( "io" "log" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc/codec" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" - "github.com/ugorji/go/codec" ) // Client is the client end that communicates with a Packer RPC server. @@ -41,11 +43,13 @@ func newClientWithMux(mux *muxBroker, streamId uint32) (*Client, error) { return nil, err } - h := &codec.MsgpackHandle{ - RawToString: true, - WriteExt: true, - } - clientCodec := codec.GoRpc.ClientCodec(clientConn, h) + // h := &codec.MsgpackHandle{ + // RawToString: true, + // WriteExt: true, + // } + var clientCodec rpc.ClientCodec + + clientCodec = codec.MsgpackClientCodec(clientConn) return &Client{ mux: mux, diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index fbfe0c538..937193ac1 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -1,13 +1,15 @@ package rpc import ( + "context" "encoding/gob" "io" "log" - "net/rpc" "os" "sync" + "github.com/hashicorp/packer/common/net/rpc" + "github.com/hashicorp/packer/packer" ) @@ -122,7 +124,8 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { cmd.SetExited(finished.ExitStatus) }() - err = c.client.Call("Communicator.Start", &args, new(interface{})) + ctx := context.TODO() + err = c.client.Call(ctx, "Communicator.Start", &args, new(interface{})) return } @@ -140,8 +143,8 @@ func (c *communicator) Upload(path string, r io.Reader, fi *os.FileInfo) (err er args.FileInfo = NewFileInfo(*fi) } - err = c.client.Call("Communicator.Upload", &args, new(interface{})) - return + ctx := context.TODO() + return c.client.Call(ctx, "Communicator.Upload", &args, new(interface{})) } func (c *communicator) UploadDir(dst string, src string, exclude []string) error { @@ -152,7 +155,8 @@ func (c *communicator) UploadDir(dst string, src string, exclude []string) error } var reply error - err := c.client.Call("Communicator.UploadDir", args, &reply) + ctx := context.TODO() + err := c.client.Call(ctx, "Communicator.UploadDir", args, &reply) if err == nil { err = reply } @@ -168,7 +172,8 @@ func (c *communicator) DownloadDir(src string, dst string, exclude []string) err } var reply error - err := c.client.Call("Communicator.DownloadDir", args, &reply) + ctx := context.TODO() + err := c.client.Call(ctx, "Communicator.DownloadDir", args, &reply) if err == nil { err = reply } @@ -192,7 +197,8 @@ func (c *communicator) Download(path string, w io.Writer) (err error) { } // Start sending data to the RPC server - err = c.client.Call("Communicator.Download", &args, new(interface{})) + ctx := context.TODO() + err = c.client.Call(ctx, "Communicator.Download", &args, new(interface{})) // Wait for the RPC server to finish receiving the data before we return <-waitServer diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 1ca368e45..c82e1ff5e 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -2,8 +2,8 @@ package rpc import ( "context" - "log" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -41,14 +41,7 @@ func (h *hook) Run(ctx context.Context, name string, ui packer.Ui, comm packer.C StreamId: nextId, } - return h.client.Call("Hook.Run", &args, new(interface{})) -} - -func (h *hook) Cancel() { - err := h.client.Call("Hook.Cancel", new(interface{}), new(interface{})) - if err != nil { - log.Printf("Hook.Cancel error: %s", err) - } + return h.client.Call(ctx, "Hook.Run", &args, new(interface{})) } func (h *HookServer) Run(ctx context.Context, args *HookRunArgs, reply *interface{}) error { diff --git a/packer/rpc/post_processor.go b/packer/rpc/post_processor.go index 1d068ae50..cfb7441d0 100644 --- a/packer/rpc/post_processor.go +++ b/packer/rpc/post_processor.go @@ -2,7 +2,8 @@ package rpc import ( "context" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -33,7 +34,8 @@ type PostProcessorProcessResponse struct { func (p *postProcessor) Configure(raw ...interface{}) (err error) { args := &PostProcessorConfigureArgs{Configs: raw} - if cerr := p.client.Call("PostProcessor.Configure", args, new(interface{})); cerr != nil { + ctx := context.TODO() + if cerr := p.client.Call(ctx, "PostProcessor.Configure", args, new(interface{})); cerr != nil { err = cerr } @@ -48,7 +50,7 @@ func (p *postProcessor) PostProcess(ctx context.Context, ui packer.Ui, a packer. go server.Serve() var response PostProcessorProcessResponse - if err := p.client.Call("PostProcessor.PostProcess", nextId, &response); err != nil { + if err := p.client.Call(ctx, "PostProcessor.PostProcess", nextId, &response); err != nil { return nil, false, err } diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index 1a66c6057..4604f3031 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -2,7 +2,8 @@ package rpc import ( "context" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -27,7 +28,8 @@ type ProvisionerPrepareArgs struct { func (p *provisioner) Prepare(configs ...interface{}) (err error) { args := &ProvisionerPrepareArgs{configs} - if cerr := p.client.Call("Provisioner.Prepare", args, new(interface{})); cerr != nil { + ctx := context.TODO() + if cerr := p.client.Call(ctx, "Provisioner.Prepare", args, new(interface{})); cerr != nil { err = cerr } @@ -41,7 +43,7 @@ func (p *provisioner) Provision(ctx context.Context, ui packer.Ui, comm packer.C server.RegisterUi(ui) go server.Serve() - return p.client.Call("Provisioner.Provision", nextId, new(interface{})) + return p.client.Call(ctx, "Provisioner.Provision", nextId, new(interface{})) } func (p *ProvisionerServer) Prepare(_ context.Context, args *ProvisionerPrepareArgs, reply *interface{}) error { diff --git a/packer/rpc/server.go b/packer/rpc/server.go index c02f75409..497083c52 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -3,10 +3,12 @@ package rpc import ( "io" "log" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc/codec" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" - "github.com/ugorji/go/codec" ) const ( @@ -125,10 +127,6 @@ func (s *Server) Serve() { } defer stream.Close() - h := &codec.MsgpackHandle{ - RawToString: true, - WriteExt: true, - } - rpcCodec := codec.GoRpc.ServerCodec(stream, h) + rpcCodec := codec.MsgpackServerCodec(stream) s.server.ServeCodec(rpcCodec) } diff --git a/packer/rpc/ui.go b/packer/rpc/ui.go index d86301673..add4970ea 100644 --- a/packer/rpc/ui.go +++ b/packer/rpc/ui.go @@ -1,8 +1,10 @@ package rpc import ( + "context" "log" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/packer" ) @@ -30,12 +32,14 @@ type UiMachineArgs struct { } func (u *Ui) Ask(query string) (result string, err error) { - err = u.client.Call("Ui.Ask", query, &result) + ctx := context.TODO() + err = u.client.Call(ctx, "Ui.Ask", query, &result) return } func (u *Ui) Error(message string) { - if err := u.client.Call("Ui.Error", message, new(interface{})); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui.Error", message, new(interface{})); err != nil { log.Printf("Error in Ui.Error RPC call: %s", err) } } @@ -46,19 +50,22 @@ func (u *Ui) Machine(t string, args ...string) { Args: args, } - if err := u.client.Call("Ui.Machine", rpcArgs, new(interface{})); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui.Machine", rpcArgs, new(interface{})); err != nil { log.Printf("Error in Ui.Machine RPC call: %s", err) } } func (u *Ui) Message(message string) { - if err := u.client.Call("Ui.Message", message, new(interface{})); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui.Message", message, new(interface{})); err != nil { log.Printf("Error in Ui.Message RPC call: %s", err) } } func (u *Ui) Say(message string) { - if err := u.client.Call("Ui.Say", message, new(interface{})); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui.Say", message, new(interface{})); err != nil { log.Printf("Error in Ui.Say RPC call: %s", err) } } diff --git a/packer/rpc/ui_progress_tracking.go b/packer/rpc/ui_progress_tracking.go index 4488c1bac..921a0d8aa 100644 --- a/packer/rpc/ui_progress_tracking.go +++ b/packer/rpc/ui_progress_tracking.go @@ -1,9 +1,11 @@ package rpc import ( + "context" "io" "log" - "net/rpc" + + "github.com/hashicorp/packer/common/net/rpc" "github.com/hashicorp/packer/common/random" ) @@ -18,7 +20,8 @@ func (u *Ui) TrackProgress(src string, currentSize, totalSize int64, stream io.R TotalSize: totalSize, } var trackingID string - if err := u.client.Call("Ui.NewTrackProgress", pl, &trackingID); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui.NewTrackProgress", pl, &trackingID); err != nil { log.Printf("Error in Ui.NewTrackProgress RPC call: %s", err) return stream } @@ -39,7 +42,8 @@ type ProgressTrackingClient struct { // Read will send len(b) over the wire instead of it's content func (u *ProgressTrackingClient) Read(b []byte) (read int, err error) { defer func() { - if err := u.client.Call("Ui"+u.id+".Add", read, new(interface{})); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui"+u.id+".Add", read, new(interface{})); err != nil { log.Printf("Error in ProgressTrackingClient.Read RPC call: %s", err) } }() @@ -48,7 +52,8 @@ func (u *ProgressTrackingClient) Read(b []byte) (read int, err error) { func (u *ProgressTrackingClient) Close() error { log.Printf("closing") - if err := u.client.Call("Ui"+u.id+".Close", nil, new(interface{})); err != nil { + ctx := context.TODO() + if err := u.client.Call(ctx, "Ui"+u.id+".Close", nil, new(interface{})); err != nil { log.Printf("Error in ProgressTrackingClient.Close RPC call: %s", err) } return u.stream.Close()