From 77c48678d6a4af459b8adf2c567de652c5f6f9c4 Mon Sep 17 00:00:00 2001 From: Billie Cleek Date: Sat, 19 Dec 2015 13:05:59 -0800 Subject: [PATCH] eliminate possible race conditions Eliminate race-y use of the packer.Ui interface by wrapping it in a concurrency-safe implementation. --- provisioner/ansible/adapter.go | 27 ++---- provisioner/ansible/adapter_test.go | 142 ++++++++++++++++++++++++++++ provisioner/ansible/provisioner.go | 48 +++++++++- 3 files changed, 197 insertions(+), 20 deletions(-) create mode 100644 provisioner/ansible/adapter_test.go diff --git a/provisioner/ansible/adapter.go b/provisioner/ansible/adapter.go index e64939390..418aa3dc7 100644 --- a/provisioner/ansible/adapter.go +++ b/provisioner/ansible/adapter.go @@ -35,16 +35,6 @@ func newAdapter(done <-chan struct{}, l net.Listener, config *ssh.ServerConfig, func (c *adapter) Serve() { c.ui.Say(fmt.Sprintf("SSH proxy: serving on %s", c.l.Addr())) - errc := make(chan error, 1) - - go func(errc chan error) { - for err := range errc { - if err != nil { - c.ui.Error(err.Error()) - } - } - }(errc) - for { // Accept will return if either the underlying connection is closed or if a connection is made. // after returning, check to see if c.done can be received. If so, then Accept() returned because @@ -52,7 +42,6 @@ func (c *adapter) Serve() { conn, err := c.l.Accept() select { case <-c.done: - close(errc) return default: if err != nil { @@ -60,14 +49,16 @@ func (c *adapter) Serve() { continue } go func(conn net.Conn) { - errc <- c.Handle(conn, errc) + if err := c.Handle(conn, c.ui); err != nil { + c.ui.Error(err.Error()) + } }(conn) } } } -func (c *adapter) Handle(conn net.Conn, errc chan<- error) error { - c.ui.Say("SSH proxy: accepted connection") +func (c *adapter) Handle(conn net.Conn, ui packer.Ui) error { + c.ui.Message("SSH proxy: accepted connection") _, chans, reqs, err := ssh.NewServerConn(conn, c.config) if err != nil { return errors.New("failed to handshake") @@ -83,9 +74,11 @@ func (c *adapter) Handle(conn net.Conn, errc chan<- error) error { continue } - go func(errc chan<- error, ch ssh.NewChannel) { - errc <- c.handleSession(ch) - }(errc, newChannel) + go func(ch ssh.NewChannel) { + if err := c.handleSession(ch); err != nil { + c.ui.Error(err.Error()) + } + }(newChannel) } return nil diff --git a/provisioner/ansible/adapter_test.go b/provisioner/ansible/adapter_test.go new file mode 100644 index 000000000..dbe8174c6 --- /dev/null +++ b/provisioner/ansible/adapter_test.go @@ -0,0 +1,142 @@ +package ansible + +import ( + "errors" + "io" + "log" + "net" + "os" + "testing" + "time" + + "github.com/mitchellh/packer/packer" + + "golang.org/x/crypto/ssh" +) + +func TestAdapter_Serve(t *testing.T) { + + // done signals the adapter that the provisioner is done + done := make(chan struct{}) + + acceptC := make(chan struct{}) + l := listener{done: make(chan struct{}), acceptC: acceptC} + + config := &ssh.ServerConfig{} + + ui := new(ui) + + sut := newAdapter(done, &l, config, "", newUi(ui), communicator{}) + go func() { + i := 0 + for range acceptC { + i++ + if i == 4 { + close(done) + l.Close() + } + } + }() + + sut.Serve() +} + +type listener struct { + done chan struct{} + acceptC chan<- struct{} + i int +} + +func (l *listener) Accept() (net.Conn, error) { + log.Println("Accept() called") + l.acceptC <- struct{}{} + select { + case <-l.done: + log.Println("done, serving an error") + return nil, errors.New("listener is closed") + + case <-time.After(10 * time.Millisecond): + l.i++ + + if l.i%2 == 0 { + c1, c2 := net.Pipe() + + go func(c net.Conn) { + <-time.After(100 * time.Millisecond) + log.Println("closing c") + c.Close() + }(c1) + + return c2, nil + } + } + + return nil, errors.New("accept error") +} + +func (l *listener) Close() error { + close(l.done) + return nil +} + +func (l *listener) Addr() net.Addr { + return addr{} +} + +type addr struct{} + +func (a addr) Network() string { + return a.String() +} + +func (a addr) String() string { + return "test" +} + +type ui int + +func (u *ui) Ask(s string) (string, error) { + *u++ + return s, nil +} + +func (u *ui) Say(s string) { + *u++ + log.Println(s) +} + +func (u *ui) Message(s string) { + *u++ + log.Println(s) +} + +func (u *ui) Error(s string) { + *u++ + log.Println(s) +} + +func (u *ui) Machine(s1 string, s2 ...string) { + *u++ + log.Println(s1) + for _, s := range s2 { + log.Println(s) + } +} + +type communicator struct{} + +func (c communicator) Start(*packer.RemoteCmd) error { + return errors.New("communicator not supported") +} + +func (c communicator) Upload(string, io.Reader, *os.FileInfo) error { + return errors.New("communicator not supported") +} + +func (c communicator) UploadDir(dst string, src string, exclude []string) error { + return errors.New("communicator not supported") +} + +func (c communicator) Download(string, io.Writer) error { + return errors.New("communicator not supported") +} diff --git a/provisioner/ansible/provisioner.go b/provisioner/ansible/provisioner.go index acfb961b6..5eaa08c12 100644 --- a/provisioner/ansible/provisioner.go +++ b/provisioner/ansible/provisioner.go @@ -170,6 +170,7 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return err } + ui = newUi(ui) p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm) defer func() { @@ -199,12 +200,11 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { }() } - if err := p.executeAnsible(ui, comm); err != nil { + if err := p.executeAnsible(ui); err != nil { return fmt.Errorf("Error executing Ansible: %s", err) } return nil - } func (p *Provisioner) Cancel() { @@ -217,7 +217,7 @@ func (p *Provisioner) Cancel() { os.Exit(0) } -func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator) error { +func (p *Provisioner) executeAnsible(ui packer.Ui) error { playbook, _ := filepath.Abs(p.config.PlaybookFile) inventory := p.config.inventoryFile @@ -275,3 +275,45 @@ func validateFileConfig(name string, config string, req bool) error { } return nil } + +// Ui provides concurrency-safe access to packer.Ui. +type Ui struct { + sem chan int + ui packer.Ui +} + +func newUi(ui packer.Ui) packer.Ui { + return &Ui{sem: make(chan int, 1), ui: ui} +} + +func (ui *Ui) Ask(s string) (string, error) { + ui.sem <- 1 + ret, err := ui.ui.Ask(s) + <-ui.sem + + return ret, err +} + +func (ui *Ui) Say(s string) { + ui.sem <- 1 + ui.ui.Say(s) + <-ui.sem +} + +func (ui *Ui) Message(s string) { + ui.sem <- 1 + ui.ui.Message(s) + <-ui.sem +} + +func (ui *Ui) Error(s string) { + ui.sem <- 1 + ui.ui.Error(s) + <-ui.sem +} + +func (ui *Ui) Machine(t string, args ...string) { + ui.sem <- 1 + ui.ui.Machine(t, args...) + <-ui.sem +}