communicator/ssh: have a Connection func so we can re-establish

[GH-152]
This commit is contained in:
Mitchell Hashimoto 2013-07-14 20:22:41 +09:00
parent db644c91fb
commit 9718a4656c
7 changed files with 173 additions and 110 deletions

View File

@ -9,13 +9,12 @@ import (
"github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer"
"log"
"net"
"time"
)
type stepConnectSSH struct {
cancel bool
conn net.Conn
comm packer.Communicator
}
func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -45,6 +44,7 @@ WaitLoop:
return multistep.ActionHalt
}
s.comm = comm
state["communicator"] = comm
break WaitLoop
case <-timeout:
@ -63,9 +63,9 @@ WaitLoop:
}
func (s *stepConnectSSH) Cleanup(map[string]interface{}) {
if s.conn != nil {
s.conn.Close()
s.conn = nil
if s.comm != nil {
// Close it TODO
s.comm = nil
}
}
@ -85,14 +85,13 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
return nil, fmt.Errorf("Error setting up SSH config: %s", err)
}
// Create the function that will be used to create the connection
connFunc := ssh.ConnectFunc(
"tcp", fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort))
ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator
var nc net.Conn
for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second)
if s.cancel {
@ -100,28 +99,29 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
return nil, errors.New("SSH wait cancelled")
}
// Attempt to connect to SSH port
log.Printf(
"Opening TCP conn for SSH to %s:%d",
instance.DNSName, config.SSHPort)
nc, err := net.Dial("tcp",
fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort))
// First just attempt a normal TCP connection that we close right
// away. We just test this in order to wait for the TCP port to be ready.
nc, err := connFunc()
if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue
}
nc.Close()
// Build the actual SSH client configuration
sshConfig := &gossh.ClientConfig{
User: config.SSHUsername,
Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring),
// Build the configuration to connect to SSH
config := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUsername,
Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring),
},
},
}
sshConnectSuccess := make(chan bool, 1)
go func() {
comm, err = ssh.New(nc, sshConfig)
comm, err = ssh.New(config)
if err != nil {
log.Printf("SSH connection fail: %s", err)
sshConnectSuccess <- false
@ -145,7 +145,5 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break
}
// Store the connection so we can close it later
s.conn = nc
return comm, nil
}

View File

@ -8,12 +8,11 @@ import (
"github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer"
"log"
"net"
"time"
)
type stepConnectSSH struct {
conn net.Conn
comm packer.Communicator
}
func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -33,11 +32,16 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
return multistep.ActionHalt
}
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ipAddress, config.SSHPort))
// Build the actual SSH client configuration
sshConfig := &gossh.ClientConfig{
User: config.SSHUsername,
Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring),
sshConfig := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUsername,
Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring),
},
},
}
@ -50,8 +54,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
var comm packer.Communicator
go func() {
var err error
ui.Say("Connecting to the droplet via SSH...")
attempts := 0
handshakeAttempts := 0
@ -62,34 +64,31 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
default:
}
attempts += 1
log.Printf(
"Opening TCP conn for SSH to %s:%d (attempt %d)",
ipAddress, config.SSHPort, attempts)
s.conn, err = net.DialTimeout(
"tcp",
fmt.Sprintf("%s:%d", ipAddress, config.SSHPort),
10*time.Second)
if err == nil {
log.Println("TCP connection made. Attempting SSH handshake.")
comm, err = ssh.New(s.conn, sshConfig)
if err == nil {
log.Println("Connected to SSH!")
break
}
handshakeAttempts += 1
log.Printf("SSH handshake error: %s", err)
if handshakeAttempts > 5 {
connected <- err
return
}
}
// A brief sleep so we're not being overly zealous attempting
// to connect to the instance.
time.Sleep(500 * time.Millisecond)
attempts += 1
nc, err := connFunc()
if err != nil {
continue
}
nc.Close()
log.Println("TCP connection made. Attempting SSH handshake.")
comm, err = ssh.New(sshConfig)
if err == nil {
log.Println("Connected to SSH!")
break
}
handshakeAttempts += 1
log.Printf("SSH handshake error: %s", err)
if handshakeAttempts > 5 {
connected <- err
return
}
}
connected <- nil
@ -125,13 +124,15 @@ ConnectWaitLoop:
}
// Set the communicator on the state bag so it can be used later
s.comm = comm
state["communicator"] = comm
return multistep.ActionContinue
}
func (s *stepConnectSSH) Cleanup(map[string]interface{}) {
if s.conn != nil {
s.conn.Close()
if s.comm != nil {
// TODO: close
s.comm = nil
}
}

View File

@ -8,7 +8,6 @@ import (
"github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer"
"log"
"net"
"time"
)
@ -24,7 +23,7 @@ import (
// communicator packer.Communicator
type stepWaitForSSH struct {
cancel bool
conn net.Conn
comm packer.Communicator
}
func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -54,6 +53,7 @@ WaitLoop:
return multistep.ActionHalt
}
s.comm = comm
state["communicator"] = comm
break WaitLoop
case <-timeout:
@ -72,9 +72,9 @@ WaitLoop:
}
func (s *stepWaitForSSH) Cleanup(map[string]interface{}) {
if s.conn != nil {
s.conn.Close()
s.conn = nil
if s.comm != nil {
// TODO: close
s.comm = nil
}
}
@ -85,14 +85,11 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
ui := state["ui"].(packer.Ui)
sshHostPort := state["sshHostPort"].(uint)
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort))
ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator
var nc net.Conn
for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second)
if s.cancel {
@ -101,25 +98,29 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
}
// Attempt to connect to SSH port
nc, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort))
nc, err := connFunc()
if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue
}
nc.Close()
// Then we attempt to connect via SSH
sshConfig := &gossh.ClientConfig{
User: config.SSHUser,
Auth: []gossh.ClientAuth{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)),
config := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUser,
Auth: []gossh.ClientAuth{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)),
},
},
}
sshConnectSuccess := make(chan bool, 1)
go func() {
comm, err = ssh.New(nc, sshConfig)
comm, err = ssh.New(config)
if err != nil {
log.Printf("SSH connection fail: %s", err)
sshConnectSuccess <- false
@ -143,7 +144,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break
}
// Store the connection so we can close it later
s.conn = nc
return comm, nil
}

View File

@ -9,7 +9,6 @@ import (
"github.com/mitchellh/packer/packer"
"io/ioutil"
"log"
"net"
"os"
"time"
)
@ -26,7 +25,7 @@ import (
// communicator packer.Communicator
type stepWaitForSSH struct {
cancel bool
conn net.Conn
comm packer.Communicator
}
func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -56,6 +55,7 @@ WaitLoop:
return multistep.ActionHalt
}
s.comm = comm
state["communicator"] = comm
break WaitLoop
case <-timeout:
@ -74,9 +74,9 @@ WaitLoop:
}
func (s *stepWaitForSSH) Cleanup(map[string]interface{}) {
if s.conn != nil {
s.conn.Close()
s.conn = nil
if s.comm != nil {
// TODO: close
s.comm = nil
}
}
@ -117,12 +117,7 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator
var nc net.Conn
for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second)
if s.cancel {
@ -146,23 +141,28 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
log.Printf("Detected IP: %s", ip)
// Attempt to connect to SSH port
nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort))
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort))
nc, err := connFunc()
if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue
}
nc.Close()
// Then we attempt to connect via SSH
sshConfig := &gossh.ClientConfig{
User: config.SSHUser,
Auth: []gossh.ClientAuth{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)),
config := &ssh.Config{
Connection: connFunc,
SSHConfig: &gossh.ClientConfig{
User: config.SSHUser,
Auth: []gossh.ClientAuth{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)),
},
},
}
comm, err = ssh.New(nc, sshConfig)
comm, err = ssh.New(config)
if err != nil {
log.Printf("SSH handshake err: %s", err)
@ -179,7 +179,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break
}
// Store the connection so we can close it later
s.conn = nc
return comm, nil
}

View File

@ -14,13 +14,34 @@ import (
type comm struct {
client *ssh.ClientConn
config *Config
conn net.Conn
}
// Config is the structure used to configure the SSH communicator.
type Config struct {
// The configuration of the Go SSH connection
SSHConfig *ssh.ClientConfig
// Connection returns a new connection. The current connection
// in use will be closed as part of the Close method, or in the
// case an error occurs.
Connection func() (net.Conn, error)
}
// Creates a new packer.Communicator implementation over SSH. This takes
// an already existing TCP connection and SSH configuration.
func New(c net.Conn, config *ssh.ClientConfig) (result *comm, err error) {
client, err := ssh.Client(c, config)
result = &comm{client}
func New(config *Config) (result *comm, err error) {
// Establish an initial connection and connect
result = &comm{
config: config,
}
if err = result.reconnect(); err != nil {
result = nil
return
}
return
}
@ -168,3 +189,17 @@ func (c *comm) Upload(path string, input io.Reader) error {
func (c *comm) Download(string, io.Writer) error {
panic("not implemented yet")
}
func (c *comm) reconnect() (err error) {
if c.conn != nil {
c.conn.Close()
}
c.conn, err = c.config.Connection()
if err != nil {
return
}
c.client, err = ssh.Client(c.conn, c.config.SSHConfig)
return
}

View File

@ -115,12 +115,20 @@ func TestNew_Invalid(t *testing.T) {
},
}
conn, err := net.Dial("tcp", newMockLineServer(t))
if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
conn := func() (net.Conn, error) {
conn, err := net.Dial("tcp", newMockLineServer(t))
if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
}
return conn, err
}
_, err = New(conn, clientConfig)
config := &Config{
Connection: conn,
SSHConfig: clientConfig,
}
_, err := New(config)
if err == nil {
t.Fatal("should have had an error connecting")
}
@ -134,12 +142,20 @@ func TestStart(t *testing.T) {
},
}
conn, err := net.Dial("tcp", newMockLineServer(t))
if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
conn := func() (net.Conn, error) {
conn, err := net.Dial("tcp", newMockLineServer(t))
if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
}
return conn, err
}
client, err := New(conn, clientConfig)
config := &Config{
Connection: conn,
SSHConfig: clientConfig,
}
client, err := New(config)
if err != nil {
t.Fatalf("error connecting to SSH: %s", err)
}

View File

@ -0,0 +1,16 @@
package ssh
import (
"log"
"net"
)
// ConnectFunc is a convenience method for returning a function
// that just uses net.Dial to communicate with the remote end that
// is suitable for use with the SSH communicator configuration.
func ConnectFunc(network, addr string) func() (net.Conn, error) {
return func() (net.Conn, error) {
log.Printf("Opening conn for SSH to %s %s", network, addr)
return net.Dial(network, addr)
}
}