communicator/ssh: have a Connection func so we can re-establish

[GH-152]
This commit is contained in:
Mitchell Hashimoto 2013-07-14 20:22:41 +09:00
parent db644c91fb
commit 9718a4656c
7 changed files with 173 additions and 110 deletions

View File

@ -9,13 +9,12 @@ import (
"github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
"net"
"time" "time"
) )
type stepConnectSSH struct { type stepConnectSSH struct {
cancel bool cancel bool
conn net.Conn comm packer.Communicator
} }
func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -45,6 +44,7 @@ WaitLoop:
return multistep.ActionHalt return multistep.ActionHalt
} }
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
break WaitLoop break WaitLoop
case <-timeout: case <-timeout:
@ -63,9 +63,9 @@ WaitLoop:
} }
func (s *stepConnectSSH) Cleanup(map[string]interface{}) { func (s *stepConnectSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // Close it TODO
s.conn = nil s.comm = nil
} }
} }
@ -85,14 +85,13 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
return nil, fmt.Errorf("Error setting up SSH config: %s", err) return nil, fmt.Errorf("Error setting up SSH config: %s", err)
} }
// Create the function that will be used to create the connection
connFunc := ssh.ConnectFunc(
"tcp", fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort))
ui.Say("Waiting for SSH to become available...") ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator var comm packer.Communicator
var nc net.Conn
for { for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
if s.cancel { if s.cancel {
@ -100,28 +99,29 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
return nil, errors.New("SSH wait cancelled") return nil, errors.New("SSH wait cancelled")
} }
// Attempt to connect to SSH port // First just attempt a normal TCP connection that we close right
log.Printf( // away. We just test this in order to wait for the TCP port to be ready.
"Opening TCP conn for SSH to %s:%d", nc, err := connFunc()
instance.DNSName, config.SSHPort)
nc, err := net.Dial("tcp",
fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort))
if err != nil { if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err) log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue continue
} }
nc.Close()
// Build the actual SSH client configuration // Build the configuration to connect to SSH
sshConfig := &gossh.ClientConfig{ config := &ssh.Config{
User: config.SSHUsername, Connection: connFunc,
Auth: []gossh.ClientAuth{ SSHConfig: &gossh.ClientConfig{
gossh.ClientAuthKeyring(keyring), User: config.SSHUsername,
Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring),
},
}, },
} }
sshConnectSuccess := make(chan bool, 1) sshConnectSuccess := make(chan bool, 1)
go func() { go func() {
comm, err = ssh.New(nc, sshConfig) comm, err = ssh.New(config)
if err != nil { if err != nil {
log.Printf("SSH connection fail: %s", err) log.Printf("SSH connection fail: %s", err)
sshConnectSuccess <- false sshConnectSuccess <- false
@ -145,7 +145,5 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break break
} }
// Store the connection so we can close it later
s.conn = nc
return comm, nil return comm, nil
} }

View File

@ -8,12 +8,11 @@ import (
"github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
"net"
"time" "time"
) )
type stepConnectSSH struct { type stepConnectSSH struct {
conn net.Conn comm packer.Communicator
} }
func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -33,11 +32,16 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
return multistep.ActionHalt return multistep.ActionHalt
} }
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ipAddress, config.SSHPort))
// Build the actual SSH client configuration // Build the actual SSH client configuration
sshConfig := &gossh.ClientConfig{ sshConfig := &ssh.Config{
User: config.SSHUsername, Connection: connFunc,
Auth: []gossh.ClientAuth{ SSHConfig: &gossh.ClientConfig{
gossh.ClientAuthKeyring(keyring), User: config.SSHUsername,
Auth: []gossh.ClientAuth{
gossh.ClientAuthKeyring(keyring),
},
}, },
} }
@ -50,8 +54,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
var comm packer.Communicator var comm packer.Communicator
go func() { go func() {
var err error
ui.Say("Connecting to the droplet via SSH...") ui.Say("Connecting to the droplet via SSH...")
attempts := 0 attempts := 0
handshakeAttempts := 0 handshakeAttempts := 0
@ -62,34 +64,31 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction
default: default:
} }
attempts += 1
log.Printf(
"Opening TCP conn for SSH to %s:%d (attempt %d)",
ipAddress, config.SSHPort, attempts)
s.conn, err = net.DialTimeout(
"tcp",
fmt.Sprintf("%s:%d", ipAddress, config.SSHPort),
10*time.Second)
if err == nil {
log.Println("TCP connection made. Attempting SSH handshake.")
comm, err = ssh.New(s.conn, sshConfig)
if err == nil {
log.Println("Connected to SSH!")
break
}
handshakeAttempts += 1
log.Printf("SSH handshake error: %s", err)
if handshakeAttempts > 5 {
connected <- err
return
}
}
// A brief sleep so we're not being overly zealous attempting // A brief sleep so we're not being overly zealous attempting
// to connect to the instance. // to connect to the instance.
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
attempts += 1
nc, err := connFunc()
if err != nil {
continue
}
nc.Close()
log.Println("TCP connection made. Attempting SSH handshake.")
comm, err = ssh.New(sshConfig)
if err == nil {
log.Println("Connected to SSH!")
break
}
handshakeAttempts += 1
log.Printf("SSH handshake error: %s", err)
if handshakeAttempts > 5 {
connected <- err
return
}
} }
connected <- nil connected <- nil
@ -125,13 +124,15 @@ ConnectWaitLoop:
} }
// Set the communicator on the state bag so it can be used later // Set the communicator on the state bag so it can be used later
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
return multistep.ActionContinue return multistep.ActionContinue
} }
func (s *stepConnectSSH) Cleanup(map[string]interface{}) { func (s *stepConnectSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // TODO: close
s.comm = nil
} }
} }

View File

@ -8,7 +8,6 @@ import (
"github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
"net"
"time" "time"
) )
@ -24,7 +23,7 @@ import (
// communicator packer.Communicator // communicator packer.Communicator
type stepWaitForSSH struct { type stepWaitForSSH struct {
cancel bool cancel bool
conn net.Conn comm packer.Communicator
} }
func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -54,6 +53,7 @@ WaitLoop:
return multistep.ActionHalt return multistep.ActionHalt
} }
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
break WaitLoop break WaitLoop
case <-timeout: case <-timeout:
@ -72,9 +72,9 @@ WaitLoop:
} }
func (s *stepWaitForSSH) Cleanup(map[string]interface{}) { func (s *stepWaitForSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // TODO: close
s.conn = nil s.comm = nil
} }
} }
@ -85,14 +85,11 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
ui := state["ui"].(packer.Ui) ui := state["ui"].(packer.Ui)
sshHostPort := state["sshHostPort"].(uint) sshHostPort := state["sshHostPort"].(uint)
connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort))
ui.Say("Waiting for SSH to become available...") ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator var comm packer.Communicator
var nc net.Conn
for { for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
if s.cancel { if s.cancel {
@ -101,25 +98,29 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
} }
// Attempt to connect to SSH port // Attempt to connect to SSH port
nc, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort)) nc, err := connFunc()
if err != nil { if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err) log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue continue
} }
nc.Close()
// Then we attempt to connect via SSH // Then we attempt to connect via SSH
sshConfig := &gossh.ClientConfig{ config := &ssh.Config{
User: config.SSHUser, Connection: connFunc,
Auth: []gossh.ClientAuth{ SSHConfig: &gossh.ClientConfig{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), User: config.SSHUser,
gossh.ClientAuthKeyboardInteractive( Auth: []gossh.ClientAuth{
ssh.PasswordKeyboardInteractive(config.SSHPassword)), gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)),
},
}, },
} }
sshConnectSuccess := make(chan bool, 1) sshConnectSuccess := make(chan bool, 1)
go func() { go func() {
comm, err = ssh.New(nc, sshConfig) comm, err = ssh.New(config)
if err != nil { if err != nil {
log.Printf("SSH connection fail: %s", err) log.Printf("SSH connection fail: %s", err)
sshConnectSuccess <- false sshConnectSuccess <- false
@ -143,7 +144,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break break
} }
// Store the connection so we can close it later
s.conn = nc
return comm, nil return comm, nil
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"os" "os"
"time" "time"
) )
@ -26,7 +25,7 @@ import (
// communicator packer.Communicator // communicator packer.Communicator
type stepWaitForSSH struct { type stepWaitForSSH struct {
cancel bool cancel bool
conn net.Conn comm packer.Communicator
} }
func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction { func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction {
@ -56,6 +55,7 @@ WaitLoop:
return multistep.ActionHalt return multistep.ActionHalt
} }
s.comm = comm
state["communicator"] = comm state["communicator"] = comm
break WaitLoop break WaitLoop
case <-timeout: case <-timeout:
@ -74,9 +74,9 @@ WaitLoop:
} }
func (s *stepWaitForSSH) Cleanup(map[string]interface{}) { func (s *stepWaitForSSH) Cleanup(map[string]interface{}) {
if s.conn != nil { if s.comm != nil {
s.conn.Close() // TODO: close
s.conn = nil s.comm = nil
} }
} }
@ -117,12 +117,7 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
ui.Say("Waiting for SSH to become available...") ui.Say("Waiting for SSH to become available...")
var comm packer.Communicator var comm packer.Communicator
var nc net.Conn
for { for {
if nc != nil {
nc.Close()
}
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
if s.cancel { if s.cancel {
@ -146,23 +141,28 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
log.Printf("Detected IP: %s", ip) log.Printf("Detected IP: %s", ip)
// Attempt to connect to SSH port // Attempt to connect to SSH port
nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort)) connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort))
nc, err := connFunc()
if err != nil { if err != nil {
log.Printf("TCP connection to SSH ip/port failed: %s", err) log.Printf("TCP connection to SSH ip/port failed: %s", err)
continue continue
} }
nc.Close()
// Then we attempt to connect via SSH // Then we attempt to connect via SSH
sshConfig := &gossh.ClientConfig{ config := &ssh.Config{
User: config.SSHUser, Connection: connFunc,
Auth: []gossh.ClientAuth{ SSHConfig: &gossh.ClientConfig{
gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), User: config.SSHUser,
gossh.ClientAuthKeyboardInteractive( Auth: []gossh.ClientAuth{
ssh.PasswordKeyboardInteractive(config.SSHPassword)), gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)),
gossh.ClientAuthKeyboardInteractive(
ssh.PasswordKeyboardInteractive(config.SSHPassword)),
},
}, },
} }
comm, err = ssh.New(nc, sshConfig) comm, err = ssh.New(config)
if err != nil { if err != nil {
log.Printf("SSH handshake err: %s", err) log.Printf("SSH handshake err: %s", err)
@ -179,7 +179,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun
break break
} }
// Store the connection so we can close it later
s.conn = nc
return comm, nil return comm, nil
} }

View File

@ -14,13 +14,34 @@ import (
type comm struct { type comm struct {
client *ssh.ClientConn client *ssh.ClientConn
config *Config
conn net.Conn
}
// Config is the structure used to configure the SSH communicator.
type Config struct {
// The configuration of the Go SSH connection
SSHConfig *ssh.ClientConfig
// Connection returns a new connection. The current connection
// in use will be closed as part of the Close method, or in the
// case an error occurs.
Connection func() (net.Conn, error)
} }
// Creates a new packer.Communicator implementation over SSH. This takes // Creates a new packer.Communicator implementation over SSH. This takes
// an already existing TCP connection and SSH configuration. // an already existing TCP connection and SSH configuration.
func New(c net.Conn, config *ssh.ClientConfig) (result *comm, err error) { func New(config *Config) (result *comm, err error) {
client, err := ssh.Client(c, config) // Establish an initial connection and connect
result = &comm{client} result = &comm{
config: config,
}
if err = result.reconnect(); err != nil {
result = nil
return
}
return return
} }
@ -168,3 +189,17 @@ func (c *comm) Upload(path string, input io.Reader) error {
func (c *comm) Download(string, io.Writer) error { func (c *comm) Download(string, io.Writer) error {
panic("not implemented yet") panic("not implemented yet")
} }
func (c *comm) reconnect() (err error) {
if c.conn != nil {
c.conn.Close()
}
c.conn, err = c.config.Connection()
if err != nil {
return
}
c.client, err = ssh.Client(c.conn, c.config.SSHConfig)
return
}

View File

@ -115,12 +115,20 @@ func TestNew_Invalid(t *testing.T) {
}, },
} }
conn, err := net.Dial("tcp", newMockLineServer(t)) conn := func() (net.Conn, error) {
if err != nil { conn, err := net.Dial("tcp", newMockLineServer(t))
t.Fatalf("unable to dial to remote side: %s", err) if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
}
return conn, err
} }
_, err = New(conn, clientConfig) config := &Config{
Connection: conn,
SSHConfig: clientConfig,
}
_, err := New(config)
if err == nil { if err == nil {
t.Fatal("should have had an error connecting") t.Fatal("should have had an error connecting")
} }
@ -134,12 +142,20 @@ func TestStart(t *testing.T) {
}, },
} }
conn, err := net.Dial("tcp", newMockLineServer(t)) conn := func() (net.Conn, error) {
if err != nil { conn, err := net.Dial("tcp", newMockLineServer(t))
t.Fatalf("unable to dial to remote side: %s", err) if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
}
return conn, err
} }
client, err := New(conn, clientConfig) config := &Config{
Connection: conn,
SSHConfig: clientConfig,
}
client, err := New(config)
if err != nil { if err != nil {
t.Fatalf("error connecting to SSH: %s", err) t.Fatalf("error connecting to SSH: %s", err)
} }

View File

@ -0,0 +1,16 @@
package ssh
import (
"log"
"net"
)
// ConnectFunc is a convenience method for returning a function
// that just uses net.Dial to communicate with the remote end that
// is suitable for use with the SSH communicator configuration.
func ConnectFunc(network, addr string) func() (net.Conn, error) {
return func() (net.Conn, error) {
log.Printf("Opening conn for SSH to %s %s", network, addr)
return net.Dial(network, addr)
}
}