eliminate possible race conditions
Eliminate race-y use of the packer.Ui interface by wrapping it in a concurrency-safe implementation.
This commit is contained in:
parent
d73e75a7cf
commit
77c48678d6
|
@ -35,16 +35,6 @@ func newAdapter(done <-chan struct{}, l net.Listener, config *ssh.ServerConfig,
|
|||
func (c *adapter) Serve() {
|
||||
c.ui.Say(fmt.Sprintf("SSH proxy: serving on %s", c.l.Addr()))
|
||||
|
||||
errc := make(chan error, 1)
|
||||
|
||||
go func(errc chan error) {
|
||||
for err := range errc {
|
||||
if err != nil {
|
||||
c.ui.Error(err.Error())
|
||||
}
|
||||
}
|
||||
}(errc)
|
||||
|
||||
for {
|
||||
// Accept will return if either the underlying connection is closed or if a connection is made.
|
||||
// after returning, check to see if c.done can be received. If so, then Accept() returned because
|
||||
|
@ -52,7 +42,6 @@ func (c *adapter) Serve() {
|
|||
conn, err := c.l.Accept()
|
||||
select {
|
||||
case <-c.done:
|
||||
close(errc)
|
||||
return
|
||||
default:
|
||||
if err != nil {
|
||||
|
@ -60,14 +49,16 @@ func (c *adapter) Serve() {
|
|||
continue
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
errc <- c.Handle(conn, errc)
|
||||
if err := c.Handle(conn, c.ui); err != nil {
|
||||
c.ui.Error(err.Error())
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *adapter) Handle(conn net.Conn, errc chan<- error) error {
|
||||
c.ui.Say("SSH proxy: accepted connection")
|
||||
func (c *adapter) Handle(conn net.Conn, ui packer.Ui) error {
|
||||
c.ui.Message("SSH proxy: accepted connection")
|
||||
_, chans, reqs, err := ssh.NewServerConn(conn, c.config)
|
||||
if err != nil {
|
||||
return errors.New("failed to handshake")
|
||||
|
@ -83,9 +74,11 @@ func (c *adapter) Handle(conn net.Conn, errc chan<- error) error {
|
|||
continue
|
||||
}
|
||||
|
||||
go func(errc chan<- error, ch ssh.NewChannel) {
|
||||
errc <- c.handleSession(ch)
|
||||
}(errc, newChannel)
|
||||
go func(ch ssh.NewChannel) {
|
||||
if err := c.handleSession(ch); err != nil {
|
||||
c.ui.Error(err.Error())
|
||||
}
|
||||
}(newChannel)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
package ansible
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/packer/packer"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestAdapter_Serve(t *testing.T) {
|
||||
|
||||
// done signals the adapter that the provisioner is done
|
||||
done := make(chan struct{})
|
||||
|
||||
acceptC := make(chan struct{})
|
||||
l := listener{done: make(chan struct{}), acceptC: acceptC}
|
||||
|
||||
config := &ssh.ServerConfig{}
|
||||
|
||||
ui := new(ui)
|
||||
|
||||
sut := newAdapter(done, &l, config, "", newUi(ui), communicator{})
|
||||
go func() {
|
||||
i := 0
|
||||
for range acceptC {
|
||||
i++
|
||||
if i == 4 {
|
||||
close(done)
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
sut.Serve()
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
done chan struct{}
|
||||
acceptC chan<- struct{}
|
||||
i int
|
||||
}
|
||||
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
log.Println("Accept() called")
|
||||
l.acceptC <- struct{}{}
|
||||
select {
|
||||
case <-l.done:
|
||||
log.Println("done, serving an error")
|
||||
return nil, errors.New("listener is closed")
|
||||
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
l.i++
|
||||
|
||||
if l.i%2 == 0 {
|
||||
c1, c2 := net.Pipe()
|
||||
|
||||
go func(c net.Conn) {
|
||||
<-time.After(100 * time.Millisecond)
|
||||
log.Println("closing c")
|
||||
c.Close()
|
||||
}(c1)
|
||||
|
||||
return c2, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("accept error")
|
||||
}
|
||||
|
||||
func (l *listener) Close() error {
|
||||
close(l.done)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return addr{}
|
||||
}
|
||||
|
||||
type addr struct{}
|
||||
|
||||
func (a addr) Network() string {
|
||||
return a.String()
|
||||
}
|
||||
|
||||
func (a addr) String() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
type ui int
|
||||
|
||||
func (u *ui) Ask(s string) (string, error) {
|
||||
*u++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (u *ui) Say(s string) {
|
||||
*u++
|
||||
log.Println(s)
|
||||
}
|
||||
|
||||
func (u *ui) Message(s string) {
|
||||
*u++
|
||||
log.Println(s)
|
||||
}
|
||||
|
||||
func (u *ui) Error(s string) {
|
||||
*u++
|
||||
log.Println(s)
|
||||
}
|
||||
|
||||
func (u *ui) Machine(s1 string, s2 ...string) {
|
||||
*u++
|
||||
log.Println(s1)
|
||||
for _, s := range s2 {
|
||||
log.Println(s)
|
||||
}
|
||||
}
|
||||
|
||||
type communicator struct{}
|
||||
|
||||
func (c communicator) Start(*packer.RemoteCmd) error {
|
||||
return errors.New("communicator not supported")
|
||||
}
|
||||
|
||||
func (c communicator) Upload(string, io.Reader, *os.FileInfo) error {
|
||||
return errors.New("communicator not supported")
|
||||
}
|
||||
|
||||
func (c communicator) UploadDir(dst string, src string, exclude []string) error {
|
||||
return errors.New("communicator not supported")
|
||||
}
|
||||
|
||||
func (c communicator) Download(string, io.Writer) error {
|
||||
return errors.New("communicator not supported")
|
||||
}
|
|
@ -170,6 +170,7 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
|
|||
return err
|
||||
}
|
||||
|
||||
ui = newUi(ui)
|
||||
p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm)
|
||||
|
||||
defer func() {
|
||||
|
@ -199,12 +200,11 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
|
|||
}()
|
||||
}
|
||||
|
||||
if err := p.executeAnsible(ui, comm); err != nil {
|
||||
if err := p.executeAnsible(ui); err != nil {
|
||||
return fmt.Errorf("Error executing Ansible: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (p *Provisioner) Cancel() {
|
||||
|
@ -217,7 +217,7 @@ func (p *Provisioner) Cancel() {
|
|||
os.Exit(0)
|
||||
}
|
||||
|
||||
func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator) error {
|
||||
func (p *Provisioner) executeAnsible(ui packer.Ui) error {
|
||||
playbook, _ := filepath.Abs(p.config.PlaybookFile)
|
||||
inventory := p.config.inventoryFile
|
||||
|
||||
|
@ -275,3 +275,45 @@ func validateFileConfig(name string, config string, req bool) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ui provides concurrency-safe access to packer.Ui.
|
||||
type Ui struct {
|
||||
sem chan int
|
||||
ui packer.Ui
|
||||
}
|
||||
|
||||
func newUi(ui packer.Ui) packer.Ui {
|
||||
return &Ui{sem: make(chan int, 1), ui: ui}
|
||||
}
|
||||
|
||||
func (ui *Ui) Ask(s string) (string, error) {
|
||||
ui.sem <- 1
|
||||
ret, err := ui.ui.Ask(s)
|
||||
<-ui.sem
|
||||
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (ui *Ui) Say(s string) {
|
||||
ui.sem <- 1
|
||||
ui.ui.Say(s)
|
||||
<-ui.sem
|
||||
}
|
||||
|
||||
func (ui *Ui) Message(s string) {
|
||||
ui.sem <- 1
|
||||
ui.ui.Message(s)
|
||||
<-ui.sem
|
||||
}
|
||||
|
||||
func (ui *Ui) Error(s string) {
|
||||
ui.sem <- 1
|
||||
ui.ui.Error(s)
|
||||
<-ui.sem
|
||||
}
|
||||
|
||||
func (ui *Ui) Machine(t string, args ...string) {
|
||||
ui.sem <- 1
|
||||
ui.ui.Machine(t, args...)
|
||||
<-ui.sem
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue