Implemented timeout around the SSH handshake, including a unit test

This commit is contained in:
Chris Bednarski 2015-07-02 03:40:47 -07:00
parent 6ca48fa3c8
commit 03850cafc6
2 changed files with 107 additions and 11 deletions

View File

@ -5,9 +5,6 @@ import (
"bytes"
"errors"
"fmt"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"io"
"io/ioutil"
"log"
@ -16,8 +13,15 @@ import (
"path/filepath"
"strconv"
"sync"
"time"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
type comm struct {
client *ssh.Client
config *Config
@ -40,6 +44,10 @@ type Config struct {
// DisableAgent, if true, will not forward the SSH agent.
DisableAgent bool
// HandshakeTimeout limits the amount of time we'll wait to handshake before
// saying the connection failed.
HandshakeTimeout time.Duration
}
// Creates a new packer.Communicator implementation over SSH. This takes
@ -273,9 +281,39 @@ func (c *comm) reconnect() (err error) {
}
log.Printf("handshaking with SSH")
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
// Default timeout to 1 minute if it wasn't specified (zero value). For
// when you need to handshake from low orbit.
var duration time.Duration
if c.config.HandshakeTimeout == 0 {
duration = 1 * time.Minute
} else {
duration = c.config.HandshakeTimeout
}
timeoutExceeded := time.After(duration)
connectionEstablished := make(chan bool, 1)
var sshConn ssh.Conn
var sshChan <-chan ssh.NewChannel
var req <-chan *ssh.Request
go func() {
sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
connectionEstablished <- true
}()
select {
case <-connectionEstablished:
// We don't need to do anything here. We just want select to block until
// we connect or timeout.
case <-timeoutExceeded:
return ErrHandshakeTimeout
}
if err != nil {
log.Printf("handshake error: %s", err)
return
}
log.Printf("handshake complete!")
if sshConn != nil {

View File

@ -5,10 +5,12 @@ package ssh
import (
"bytes"
"fmt"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
"net"
"testing"
"time"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
)
// private key for mock server
@ -94,6 +96,28 @@ func newMockLineServer(t *testing.T) string {
return l.Addr().String()
}
func newMockBrokenServer(t *testing.T) string {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable tp listen for connection: %s", err)
}
go func() {
defer l.Close()
c, err := l.Accept()
if err != nil {
t.Errorf("Unable to accept incoming connection: %s", err)
}
defer c.Close()
// This should block for a period of time longer than our timeout in
// the test case. That way we invoke a failure scenario.
time.Sleep(5 * time.Second)
t.Log("Block on handshaking for SSH connection")
}()
return l.Addr().String()
}
func TestCommIsCommunicator(t *testing.T) {
var raw interface{}
raw = &comm{}
@ -157,10 +181,44 @@ func TestStart(t *testing.T) {
t.Fatalf("error connecting to SSH: %s", err)
}
var cmd packer.RemoteCmd
stdout := new(bytes.Buffer)
cmd.Command = "echo foo"
cmd.Stdout = stdout
cmd := &packer.RemoteCmd{
Command: "echo foo",
Stdout: new(bytes.Buffer),
}
client.Start(&cmd)
client.Start(cmd)
}
func TestHandshakeTimeout(t *testing.T) {
clientConfig := &ssh.ClientConfig{
User: "user",
Auth: []ssh.AuthMethod{
ssh.Password("pass"),
},
}
address := newMockBrokenServer(t)
conn := func() (net.Conn, error) {
conn, err := net.Dial("tcp", address)
if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
}
return conn, err
}
config := &Config{
Connection: conn,
SSHConfig: clientConfig,
HandshakeTimeout: 50 * time.Millisecond,
}
_, err := New(address, config)
if err != ErrHandshakeTimeout {
// Note: there's another error that can come back from this call:
// ssh: handshake failed: EOF
// This should appear in cases where the handshake fails because of
// malformed (or no) data sent back by the server, but should not happen
// in a timeout scenario.
t.Fatalf("Expected handshake timeout, got: %s", err)
}
}