package rpc import ( "encoding/gob" "errors" "github.com/mitchellh/packer/packer" "io" "log" "net" "net/rpc" ) // An implementation of packer.Communicator where the communicator is actually // executed over an RPC connection. type communicator struct { client *rpc.Client } // CommunicatorServer wraps a packer.Communicator implementation and makes // it exportable as part of a Golang RPC server. type CommunicatorServer struct { c packer.Communicator } type CommandFinished struct { ExitStatus int } type CommunicatorStartArgs struct { Command string StdinAddress string StdoutAddress string StderrAddress string ResponseAddress string } type CommunicatorDownloadArgs struct { Path string WriterAddress string } type CommunicatorUploadArgs struct { Path string ReaderAddress string } type CommunicatorUploadDirArgs struct { Dst string Src string Exclude []string } func Communicator(client *rpc.Client) *communicator { return &communicator{client} } func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) { var args CommunicatorStartArgs args.Command = cmd.Command if cmd.Stdin != nil { stdinL := netListenerInRange(portRangeMin, portRangeMax) args.StdinAddress = stdinL.Addr().String() go serveSingleCopy("stdin", stdinL, nil, cmd.Stdin) } if cmd.Stdout != nil { stdoutL := netListenerInRange(portRangeMin, portRangeMax) args.StdoutAddress = stdoutL.Addr().String() go serveSingleCopy("stdout", stdoutL, cmd.Stdout, nil) } if cmd.Stderr != nil { stderrL := netListenerInRange(portRangeMin, portRangeMax) args.StderrAddress = stderrL.Addr().String() go serveSingleCopy("stderr", stderrL, cmd.Stderr, nil) } responseL := netListenerInRange(portRangeMin, portRangeMax) args.ResponseAddress = responseL.Addr().String() go func() { defer responseL.Close() conn, err := responseL.Accept() if err != nil { cmd.SetExited(123) return } defer conn.Close() decoder := gob.NewDecoder(conn) var finished CommandFinished if err := decoder.Decode(&finished); err != nil { cmd.SetExited(123) return } cmd.SetExited(finished.ExitStatus) }() err = c.client.Call("Communicator.Start", &args, new(interface{})) return } func (c *communicator) Upload(path string, r io.Reader) (err error) { // We need to create a server that can proxy the reader data // over because we can't simply gob encode an io.Reader readerL := netListenerInRange(portRangeMin, portRangeMax) if readerL == nil { err = errors.New("couldn't allocate listener for upload reader") return } // Make sure at the end of this call, we close the listener defer readerL.Close() // Pipe the reader through to the connection go serveSingleCopy("uploadReader", readerL, nil, r) args := CommunicatorUploadArgs{ path, readerL.Addr().String(), } err = c.client.Call("Communicator.Upload", &args, new(interface{})) return } func (c *communicator) UploadDir(dst string, src string, exclude []string) error { args := &CommunicatorUploadDirArgs{ Dst: dst, Src: src, Exclude: exclude, } var reply error err := c.client.Call("Communicator.UploadDir", args, &reply) if err == nil { err = reply } return err } func (c *communicator) Download(path string, w io.Writer) (err error) { // We need to create a server that can proxy that data downloaded // into the writer because we can't gob encode a writer directly. writerL := netListenerInRange(portRangeMin, portRangeMax) if writerL == nil { err = errors.New("couldn't allocate listener for download writer") return } // Make sure we close the listener once we're done because we'll be done defer writerL.Close() // Serve a single connection and a single copy go serveSingleCopy("downloadWriter", writerL, w, nil) args := CommunicatorDownloadArgs{ path, writerL.Addr().String(), } err = c.client.Call("Communicator.Download", &args, new(interface{})) return } func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (err error) { // Build the RemoteCmd on this side so that it all pipes over // to the remote side. var cmd packer.RemoteCmd cmd.Command = args.Command toClose := make([]net.Conn, 0) if args.StdinAddress != "" { stdinC, err := tcpDial(args.StdinAddress) if err != nil { return err } toClose = append(toClose, stdinC) cmd.Stdin = stdinC } if args.StdoutAddress != "" { stdoutC, err := tcpDial(args.StdoutAddress) if err != nil { return err } toClose = append(toClose, stdoutC) cmd.Stdout = stdoutC } if args.StderrAddress != "" { stderrC, err := tcpDial(args.StderrAddress) if err != nil { return err } toClose = append(toClose, stderrC) cmd.Stderr = stderrC } // Connect to the response address so we can write our result to it // when ready. responseC, err := tcpDial(args.ResponseAddress) if err != nil { return err } responseWriter := gob.NewEncoder(responseC) // Start the actual command err = c.c.Start(&cmd) // Start a goroutine to spin and wait for the process to actual // exit. When it does, report it back to caller... go func() { defer responseC.Close() for _, conn := range toClose { defer conn.Close() } cmd.Wait() responseWriter.Encode(&CommandFinished{cmd.ExitStatus}) }() return } func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) { readerC, err := tcpDial(args.ReaderAddress) if err != nil { return } defer readerC.Close() err = c.c.Upload(args.Path, readerC) return } func (c *CommunicatorServer) UploadDir(args *CommunicatorUploadDirArgs, reply *error) error { return c.c.UploadDir(args.Dst, args.Src, args.Exclude) } func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) { writerC, err := tcpDial(args.WriterAddress) if err != nil { return } defer writerC.Close() err = c.c.Download(args.Path, writerC) return } func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) { defer l.Close() conn, err := l.Accept() if err != nil { log.Printf("'%s' accept error: %s", name, err) return } // Be sure to close the connection after we're done copying so // that an EOF will successfully be sent to the remote side defer conn.Close() // The connection is the destination/source that is nil if dst == nil { dst = conn } else { src = conn } written, err := io.Copy(dst, src) log.Printf("%d bytes written for '%s'", written, name) if err != nil { log.Printf("'%s' copy error: %s", name, err) } }