Implemented timeout around the SSH handshake, including a unit test
This commit is contained in:
parent
6ca48fa3c8
commit
03850cafc6
|
@ -5,9 +5,6 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/mitchellh/packer/packer"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
@ -16,8 +13,15 @@ import (
|
|||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/packer/packer"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
|
||||
|
||||
type comm struct {
|
||||
client *ssh.Client
|
||||
config *Config
|
||||
|
@ -40,6 +44,10 @@ type Config struct {
|
|||
|
||||
// DisableAgent, if true, will not forward the SSH agent.
|
||||
DisableAgent bool
|
||||
|
||||
// HandshakeTimeout limits the amount of time we'll wait to handshake before
|
||||
// saying the connection failed.
|
||||
HandshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
// Creates a new packer.Communicator implementation over SSH. This takes
|
||||
|
@ -273,9 +281,39 @@ func (c *comm) reconnect() (err error) {
|
|||
}
|
||||
|
||||
log.Printf("handshaking with SSH")
|
||||
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
|
||||
|
||||
// Default timeout to 1 minute if it wasn't specified (zero value). For
|
||||
// when you need to handshake from low orbit.
|
||||
var duration time.Duration
|
||||
if c.config.HandshakeTimeout == 0 {
|
||||
duration = 1 * time.Minute
|
||||
} else {
|
||||
duration = c.config.HandshakeTimeout
|
||||
}
|
||||
|
||||
timeoutExceeded := time.After(duration)
|
||||
connectionEstablished := make(chan bool, 1)
|
||||
|
||||
var sshConn ssh.Conn
|
||||
var sshChan <-chan ssh.NewChannel
|
||||
var req <-chan *ssh.Request
|
||||
|
||||
go func() {
|
||||
sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
|
||||
connectionEstablished <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-connectionEstablished:
|
||||
// We don't need to do anything here. We just want select to block until
|
||||
// we connect or timeout.
|
||||
case <-timeoutExceeded:
|
||||
return ErrHandshakeTimeout
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("handshake error: %s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("handshake complete!")
|
||||
if sshConn != nil {
|
||||
|
|
|
@ -5,10 +5,12 @@ package ssh
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/mitchellh/packer/packer"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/packer/packer"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// private key for mock server
|
||||
|
@ -94,6 +96,28 @@ func newMockLineServer(t *testing.T) string {
|
|||
return l.Addr().String()
|
||||
}
|
||||
|
||||
func newMockBrokenServer(t *testing.T) string {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable tp listen for connection: %s", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer l.Close()
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to accept incoming connection: %s", err)
|
||||
}
|
||||
defer c.Close()
|
||||
// This should block for a period of time longer than our timeout in
|
||||
// the test case. That way we invoke a failure scenario.
|
||||
time.Sleep(5 * time.Second)
|
||||
t.Log("Block on handshaking for SSH connection")
|
||||
}()
|
||||
|
||||
return l.Addr().String()
|
||||
}
|
||||
|
||||
func TestCommIsCommunicator(t *testing.T) {
|
||||
var raw interface{}
|
||||
raw = &comm{}
|
||||
|
@ -157,10 +181,44 @@ func TestStart(t *testing.T) {
|
|||
t.Fatalf("error connecting to SSH: %s", err)
|
||||
}
|
||||
|
||||
var cmd packer.RemoteCmd
|
||||
stdout := new(bytes.Buffer)
|
||||
cmd.Command = "echo foo"
|
||||
cmd.Stdout = stdout
|
||||
cmd := &packer.RemoteCmd{
|
||||
Command: "echo foo",
|
||||
Stdout: new(bytes.Buffer),
|
||||
}
|
||||
|
||||
client.Start(&cmd)
|
||||
client.Start(cmd)
|
||||
}
|
||||
|
||||
func TestHandshakeTimeout(t *testing.T) {
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "user",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password("pass"),
|
||||
},
|
||||
}
|
||||
|
||||
address := newMockBrokenServer(t)
|
||||
conn := func() (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to dial to remote side: %s", err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Connection: conn,
|
||||
SSHConfig: clientConfig,
|
||||
HandshakeTimeout: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
_, err := New(address, config)
|
||||
if err != ErrHandshakeTimeout {
|
||||
// Note: there's another error that can come back from this call:
|
||||
// ssh: handshake failed: EOF
|
||||
// This should appear in cases where the handshake fails because of
|
||||
// malformed (or no) data sent back by the server, but should not happen
|
||||
// in a timeout scenario.
|
||||
t.Fatalf("Expected handshake timeout, got: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue