Merge pull request #5089 from hashicorp/4719

rpc/communicator fix race condition that causes stdout from ssh provi…
This commit is contained in:
Megan Marsh 2017-07-06 13:34:05 -07:00 committed by GitHub
commit 4313f98061
1 changed files with 28 additions and 5 deletions

View File

@ -6,6 +6,7 @@ import (
"log" "log"
"net/rpc" "net/rpc"
"os" "os"
"sync"
"github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/packer"
) )
@ -67,19 +68,33 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
var args CommunicatorStartArgs var args CommunicatorStartArgs
args.Command = cmd.Command args.Command = cmd.Command
var wg sync.WaitGroup
if cmd.Stdin != nil { if cmd.Stdin != nil {
wg.Add(1)
args.StdinStreamId = c.mux.NextId() args.StdinStreamId = c.mux.NextId()
go serveSingleCopy("stdin", c.mux, args.StdinStreamId, nil, cmd.Stdin) go func() {
defer wg.Done()
serveSingleCopy("stdin", c.mux, args.StdinStreamId, nil, cmd.Stdin)
}()
} }
if cmd.Stdout != nil { if cmd.Stdout != nil {
wg.Add(1)
args.StdoutStreamId = c.mux.NextId() args.StdoutStreamId = c.mux.NextId()
go serveSingleCopy("stdout", c.mux, args.StdoutStreamId, cmd.Stdout, nil) go func() {
defer wg.Done()
serveSingleCopy("stdout", c.mux, args.StdoutStreamId, cmd.Stdout, nil)
}()
} }
if cmd.Stderr != nil { if cmd.Stderr != nil {
wg.Add(1)
args.StderrStreamId = c.mux.NextId() args.StderrStreamId = c.mux.NextId()
go serveSingleCopy("stderr", c.mux, args.StderrStreamId, cmd.Stderr, nil) go func() {
defer wg.Done()
serveSingleCopy("stderr", c.mux, args.StderrStreamId, cmd.Stderr, nil)
}()
} }
responseStreamId := c.mux.NextId() responseStreamId := c.mux.NextId()
@ -87,6 +102,7 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
go func() { go func() {
conn, err := c.mux.Accept(responseStreamId) conn, err := c.mux.Accept(responseStreamId)
wg.Wait()
if err != nil { if err != nil {
log.Printf("[ERR] Error accepting response stream %d: %s", log.Printf("[ERR] Error accepting response stream %d: %s",
responseStreamId, err) responseStreamId, err)
@ -97,7 +113,8 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
var finished CommandFinished var finished CommandFinished
decoder := gob.NewDecoder(conn) decoder := gob.NewDecoder(conn)
if err := decoder.Decode(&finished); err != nil { err = decoder.Decode(&finished)
if err != nil {
log.Printf("[ERR] Error decoding response stream %d: %s", log.Printf("[ERR] Error decoding response stream %d: %s",
responseStreamId, err) responseStreamId, err)
cmd.SetExited(123) cmd.SetExited(123)
@ -115,7 +132,12 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
func (c *communicator) Upload(path string, r io.Reader, fi *os.FileInfo) (err error) { func (c *communicator) Upload(path string, r io.Reader, fi *os.FileInfo) (err error) {
// Pipe the reader through to the connection // Pipe the reader through to the connection
streamId := c.mux.NextId() streamId := c.mux.NextId()
go serveSingleCopy("uploadData", c.mux, streamId, nil, r) var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
serveSingleCopy("uploadData", c.mux, streamId, nil, r)
}()
args := CommunicatorUploadArgs{ args := CommunicatorUploadArgs{
Path: path, Path: path,
@ -127,6 +149,7 @@ func (c *communicator) Upload(path string, r io.Reader, fi *os.FileInfo) (err er
} }
err = c.client.Call("Communicator.Upload", &args, new(interface{})) err = c.client.Call("Communicator.Upload", &args, new(interface{}))
wg.Wait()
return return
} }