From 03850cafc646600d00b2e038ef81286f9e2fe392 Mon Sep 17 00:00:00 2001 From: Chris Bednarski Date: Thu, 2 Jul 2015 03:40:47 -0700 Subject: [PATCH] Implemented timeout around the SSH handshake, including a unit test --- communicator/ssh/communicator.go | 46 +++++++++++++++-- communicator/ssh/communicator_test.go | 72 ++++++++++++++++++++++++--- 2 files changed, 107 insertions(+), 11 deletions(-) diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index cc61e8e9f..f05f6e46e 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -5,9 +5,6 @@ import ( "bytes" "errors" "fmt" - "github.com/mitchellh/packer/packer" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" "io" "io/ioutil" "log" @@ -16,8 +13,15 @@ import ( "path/filepath" "strconv" "sync" + "time" + + "github.com/mitchellh/packer/packer" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) +var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake") + type comm struct { client *ssh.Client config *Config @@ -40,6 +44,10 @@ type Config struct { // DisableAgent, if true, will not forward the SSH agent. DisableAgent bool + + // HandshakeTimeout limits the amount of time we'll wait to handshake before + // saying the connection failed. + HandshakeTimeout time.Duration } // Creates a new packer.Communicator implementation over SSH. This takes @@ -273,9 +281,39 @@ func (c *comm) reconnect() (err error) { } log.Printf("handshaking with SSH") - sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) + + // Default timeout to 1 minute if it wasn't specified (zero value). For + // when you need to handshake from low orbit. + var duration time.Duration + if c.config.HandshakeTimeout == 0 { + duration = 1 * time.Minute + } else { + duration = c.config.HandshakeTimeout + } + + timeoutExceeded := time.After(duration) + connectionEstablished := make(chan bool, 1) + + var sshConn ssh.Conn + var sshChan <-chan ssh.NewChannel + var req <-chan *ssh.Request + + go func() { + sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig) + connectionEstablished <- true + }() + + select { + case <-connectionEstablished: + // We don't need to do anything here. We just want select to block until + // we connect or timeout. + case <-timeoutExceeded: + return ErrHandshakeTimeout + } + if err != nil { log.Printf("handshake error: %s", err) + return } log.Printf("handshake complete!") if sshConn != nil { diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index e9f73d2dc..6398bd713 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -5,10 +5,12 @@ package ssh import ( "bytes" "fmt" - "github.com/mitchellh/packer/packer" - "golang.org/x/crypto/ssh" "net" "testing" + "time" + + "github.com/mitchellh/packer/packer" + "golang.org/x/crypto/ssh" ) // private key for mock server @@ -94,6 +96,28 @@ func newMockLineServer(t *testing.T) string { return l.Addr().String() } +func newMockBrokenServer(t *testing.T) string { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Unable tp listen for connection: %s", err) + } + + go func() { + defer l.Close() + c, err := l.Accept() + if err != nil { + t.Errorf("Unable to accept incoming connection: %s", err) + } + defer c.Close() + // This should block for a period of time longer than our timeout in + // the test case. That way we invoke a failure scenario. + time.Sleep(5 * time.Second) + t.Log("Block on handshaking for SSH connection") + }() + + return l.Addr().String() +} + func TestCommIsCommunicator(t *testing.T) { var raw interface{} raw = &comm{} @@ -157,10 +181,44 @@ func TestStart(t *testing.T) { t.Fatalf("error connecting to SSH: %s", err) } - var cmd packer.RemoteCmd - stdout := new(bytes.Buffer) - cmd.Command = "echo foo" - cmd.Stdout = stdout + cmd := &packer.RemoteCmd{ + Command: "echo foo", + Stdout: new(bytes.Buffer), + } - client.Start(&cmd) + client.Start(cmd) +} + +func TestHandshakeTimeout(t *testing.T) { + clientConfig := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + ssh.Password("pass"), + }, + } + + address := newMockBrokenServer(t) + conn := func() (net.Conn, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("unable to dial to remote side: %s", err) + } + return conn, err + } + + config := &Config{ + Connection: conn, + SSHConfig: clientConfig, + HandshakeTimeout: 50 * time.Millisecond, + } + + _, err := New(address, config) + if err != ErrHandshakeTimeout { + // Note: there's another error that can come back from this call: + // ssh: handshake failed: EOF + // This should appear in cases where the handshake fails because of + // malformed (or no) data sent back by the server, but should not happen + // in a timeout scenario. + t.Fatalf("Expected handshake timeout, got: %s", err) + } }