diff --git a/builder/amazonebs/step_connect_ssh.go b/builder/amazonebs/step_connect_ssh.go index 903439341..5a3c71222 100644 --- a/builder/amazonebs/step_connect_ssh.go +++ b/builder/amazonebs/step_connect_ssh.go @@ -9,13 +9,12 @@ import ( "github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/packer" "log" - "net" "time" ) type stepConnectSSH struct { cancel bool - conn net.Conn + comm packer.Communicator } func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction { @@ -45,6 +44,7 @@ WaitLoop: return multistep.ActionHalt } + s.comm = comm state["communicator"] = comm break WaitLoop case <-timeout: @@ -63,9 +63,9 @@ WaitLoop: } func (s *stepConnectSSH) Cleanup(map[string]interface{}) { - if s.conn != nil { - s.conn.Close() - s.conn = nil + if s.comm != nil { + // Close it TODO + s.comm = nil } } @@ -85,14 +85,13 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun return nil, fmt.Errorf("Error setting up SSH config: %s", err) } + // Create the function that will be used to create the connection + connFunc := ssh.ConnectFunc( + "tcp", fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort)) + ui.Say("Waiting for SSH to become available...") var comm packer.Communicator - var nc net.Conn for { - if nc != nil { - nc.Close() - } - time.Sleep(5 * time.Second) if s.cancel { @@ -100,28 +99,29 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun return nil, errors.New("SSH wait cancelled") } - // Attempt to connect to SSH port - log.Printf( - "Opening TCP conn for SSH to %s:%d", - instance.DNSName, config.SSHPort) - nc, err := net.Dial("tcp", - fmt.Sprintf("%s:%d", instance.DNSName, config.SSHPort)) + // First just attempt a normal TCP connection that we close right + // away. We just test this in order to wait for the TCP port to be ready. + nc, err := connFunc() if err != nil { log.Printf("TCP connection to SSH ip/port failed: %s", err) continue } + nc.Close() - // Build the actual SSH client configuration - sshConfig := &gossh.ClientConfig{ - User: config.SSHUsername, - Auth: []gossh.ClientAuth{ - gossh.ClientAuthKeyring(keyring), + // Build the configuration to connect to SSH + config := &ssh.Config{ + Connection: connFunc, + SSHConfig: &gossh.ClientConfig{ + User: config.SSHUsername, + Auth: []gossh.ClientAuth{ + gossh.ClientAuthKeyring(keyring), + }, }, } sshConnectSuccess := make(chan bool, 1) go func() { - comm, err = ssh.New(nc, sshConfig) + comm, err = ssh.New(config) if err != nil { log.Printf("SSH connection fail: %s", err) sshConnectSuccess <- false @@ -145,7 +145,5 @@ func (s *stepConnectSSH) waitForSSH(state map[string]interface{}) (packer.Commun break } - // Store the connection so we can close it later - s.conn = nc return comm, nil } diff --git a/builder/digitalocean/step_connect_ssh.go b/builder/digitalocean/step_connect_ssh.go index c6fcf7db6..dd9978123 100644 --- a/builder/digitalocean/step_connect_ssh.go +++ b/builder/digitalocean/step_connect_ssh.go @@ -8,12 +8,11 @@ import ( "github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/packer" "log" - "net" "time" ) type stepConnectSSH struct { - conn net.Conn + comm packer.Communicator } func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction { @@ -33,11 +32,16 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction return multistep.ActionHalt } + connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ipAddress, config.SSHPort)) + // Build the actual SSH client configuration - sshConfig := &gossh.ClientConfig{ - User: config.SSHUsername, - Auth: []gossh.ClientAuth{ - gossh.ClientAuthKeyring(keyring), + sshConfig := &ssh.Config{ + Connection: connFunc, + SSHConfig: &gossh.ClientConfig{ + User: config.SSHUsername, + Auth: []gossh.ClientAuth{ + gossh.ClientAuthKeyring(keyring), + }, }, } @@ -50,8 +54,6 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction var comm packer.Communicator go func() { - var err error - ui.Say("Connecting to the droplet via SSH...") attempts := 0 handshakeAttempts := 0 @@ -62,34 +64,31 @@ func (s *stepConnectSSH) Run(state map[string]interface{}) multistep.StepAction default: } - attempts += 1 - log.Printf( - "Opening TCP conn for SSH to %s:%d (attempt %d)", - ipAddress, config.SSHPort, attempts) - s.conn, err = net.DialTimeout( - "tcp", - fmt.Sprintf("%s:%d", ipAddress, config.SSHPort), - 10*time.Second) - if err == nil { - log.Println("TCP connection made. Attempting SSH handshake.") - comm, err = ssh.New(s.conn, sshConfig) - if err == nil { - log.Println("Connected to SSH!") - break - } - - handshakeAttempts += 1 - log.Printf("SSH handshake error: %s", err) - - if handshakeAttempts > 5 { - connected <- err - return - } - } - // A brief sleep so we're not being overly zealous attempting // to connect to the instance. time.Sleep(500 * time.Millisecond) + + attempts += 1 + nc, err := connFunc() + if err != nil { + continue + } + nc.Close() + + log.Println("TCP connection made. Attempting SSH handshake.") + comm, err = ssh.New(sshConfig) + if err == nil { + log.Println("Connected to SSH!") + break + } + + handshakeAttempts += 1 + log.Printf("SSH handshake error: %s", err) + + if handshakeAttempts > 5 { + connected <- err + return + } } connected <- nil @@ -125,13 +124,15 @@ ConnectWaitLoop: } // Set the communicator on the state bag so it can be used later + s.comm = comm state["communicator"] = comm return multistep.ActionContinue } func (s *stepConnectSSH) Cleanup(map[string]interface{}) { - if s.conn != nil { - s.conn.Close() + if s.comm != nil { + // TODO: close + s.comm = nil } } diff --git a/builder/virtualbox/step_wait_for_ssh.go b/builder/virtualbox/step_wait_for_ssh.go index ac039e801..ca07c05ac 100644 --- a/builder/virtualbox/step_wait_for_ssh.go +++ b/builder/virtualbox/step_wait_for_ssh.go @@ -8,7 +8,6 @@ import ( "github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/packer" "log" - "net" "time" ) @@ -24,7 +23,7 @@ import ( // communicator packer.Communicator type stepWaitForSSH struct { cancel bool - conn net.Conn + comm packer.Communicator } func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction { @@ -54,6 +53,7 @@ WaitLoop: return multistep.ActionHalt } + s.comm = comm state["communicator"] = comm break WaitLoop case <-timeout: @@ -72,9 +72,9 @@ WaitLoop: } func (s *stepWaitForSSH) Cleanup(map[string]interface{}) { - if s.conn != nil { - s.conn.Close() - s.conn = nil + if s.comm != nil { + // TODO: close + s.comm = nil } } @@ -85,14 +85,11 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ui := state["ui"].(packer.Ui) sshHostPort := state["sshHostPort"].(uint) + connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort)) + ui.Say("Waiting for SSH to become available...") var comm packer.Communicator - var nc net.Conn for { - if nc != nil { - nc.Close() - } - time.Sleep(5 * time.Second) if s.cancel { @@ -101,25 +98,29 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun } // Attempt to connect to SSH port - nc, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", sshHostPort)) + nc, err := connFunc() if err != nil { log.Printf("TCP connection to SSH ip/port failed: %s", err) continue } + nc.Close() // Then we attempt to connect via SSH - sshConfig := &gossh.ClientConfig{ - User: config.SSHUser, - Auth: []gossh.ClientAuth{ - gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), - gossh.ClientAuthKeyboardInteractive( - ssh.PasswordKeyboardInteractive(config.SSHPassword)), + config := &ssh.Config{ + Connection: connFunc, + SSHConfig: &gossh.ClientConfig{ + User: config.SSHUser, + Auth: []gossh.ClientAuth{ + gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), + gossh.ClientAuthKeyboardInteractive( + ssh.PasswordKeyboardInteractive(config.SSHPassword)), + }, }, } sshConnectSuccess := make(chan bool, 1) go func() { - comm, err = ssh.New(nc, sshConfig) + comm, err = ssh.New(config) if err != nil { log.Printf("SSH connection fail: %s", err) sshConnectSuccess <- false @@ -143,7 +144,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun break } - // Store the connection so we can close it later - s.conn = nc return comm, nil } diff --git a/builder/vmware/step_wait_for_ssh.go b/builder/vmware/step_wait_for_ssh.go index 03457ad9f..9df52f82e 100644 --- a/builder/vmware/step_wait_for_ssh.go +++ b/builder/vmware/step_wait_for_ssh.go @@ -9,7 +9,6 @@ import ( "github.com/mitchellh/packer/packer" "io/ioutil" "log" - "net" "os" "time" ) @@ -26,7 +25,7 @@ import ( // communicator packer.Communicator type stepWaitForSSH struct { cancel bool - conn net.Conn + comm packer.Communicator } func (s *stepWaitForSSH) Run(state map[string]interface{}) multistep.StepAction { @@ -56,6 +55,7 @@ WaitLoop: return multistep.ActionHalt } + s.comm = comm state["communicator"] = comm break WaitLoop case <-timeout: @@ -74,9 +74,9 @@ WaitLoop: } func (s *stepWaitForSSH) Cleanup(map[string]interface{}) { - if s.conn != nil { - s.conn.Close() - s.conn = nil + if s.comm != nil { + // TODO: close + s.comm = nil } } @@ -117,12 +117,7 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun ui.Say("Waiting for SSH to become available...") var comm packer.Communicator - var nc net.Conn for { - if nc != nil { - nc.Close() - } - time.Sleep(5 * time.Second) if s.cancel { @@ -146,23 +141,28 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun log.Printf("Detected IP: %s", ip) // Attempt to connect to SSH port - nc, err = net.Dial("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort)) + connFunc := ssh.ConnectFunc("tcp", fmt.Sprintf("%s:%d", ip, config.SSHPort)) + nc, err := connFunc() if err != nil { log.Printf("TCP connection to SSH ip/port failed: %s", err) continue } + nc.Close() // Then we attempt to connect via SSH - sshConfig := &gossh.ClientConfig{ - User: config.SSHUser, - Auth: []gossh.ClientAuth{ - gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), - gossh.ClientAuthKeyboardInteractive( - ssh.PasswordKeyboardInteractive(config.SSHPassword)), + config := &ssh.Config{ + Connection: connFunc, + SSHConfig: &gossh.ClientConfig{ + User: config.SSHUser, + Auth: []gossh.ClientAuth{ + gossh.ClientAuthPassword(ssh.Password(config.SSHPassword)), + gossh.ClientAuthKeyboardInteractive( + ssh.PasswordKeyboardInteractive(config.SSHPassword)), + }, }, } - comm, err = ssh.New(nc, sshConfig) + comm, err = ssh.New(config) if err != nil { log.Printf("SSH handshake err: %s", err) @@ -179,7 +179,5 @@ func (s *stepWaitForSSH) waitForSSH(state map[string]interface{}) (packer.Commun break } - // Store the connection so we can close it later - s.conn = nc return comm, nil } diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 50bf9c9e3..ba5070ae6 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -14,13 +14,34 @@ import ( type comm struct { client *ssh.ClientConn + config *Config + conn net.Conn +} + +// Config is the structure used to configure the SSH communicator. +type Config struct { + // The configuration of the Go SSH connection + SSHConfig *ssh.ClientConfig + + // Connection returns a new connection. The current connection + // in use will be closed as part of the Close method, or in the + // case an error occurs. + Connection func() (net.Conn, error) } // Creates a new packer.Communicator implementation over SSH. This takes // an already existing TCP connection and SSH configuration. -func New(c net.Conn, config *ssh.ClientConfig) (result *comm, err error) { - client, err := ssh.Client(c, config) - result = &comm{client} +func New(config *Config) (result *comm, err error) { + // Establish an initial connection and connect + result = &comm{ + config: config, + } + + if err = result.reconnect(); err != nil { + result = nil + return + } + return } @@ -168,3 +189,17 @@ func (c *comm) Upload(path string, input io.Reader) error { func (c *comm) Download(string, io.Writer) error { panic("not implemented yet") } + +func (c *comm) reconnect() (err error) { + if c.conn != nil { + c.conn.Close() + } + + c.conn, err = c.config.Connection() + if err != nil { + return + } + + c.client, err = ssh.Client(c.conn, c.config.SSHConfig) + return +} diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index 1831bac90..42ba0bb7e 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -115,12 +115,20 @@ func TestNew_Invalid(t *testing.T) { }, } - conn, err := net.Dial("tcp", newMockLineServer(t)) - if err != nil { - t.Fatalf("unable to dial to remote side: %s", err) + conn := func() (net.Conn, error) { + conn, err := net.Dial("tcp", newMockLineServer(t)) + if err != nil { + t.Fatalf("unable to dial to remote side: %s", err) + } + return conn, err } - _, err = New(conn, clientConfig) + config := &Config{ + Connection: conn, + SSHConfig: clientConfig, + } + + _, err := New(config) if err == nil { t.Fatal("should have had an error connecting") } @@ -134,12 +142,20 @@ func TestStart(t *testing.T) { }, } - conn, err := net.Dial("tcp", newMockLineServer(t)) - if err != nil { - t.Fatalf("unable to dial to remote side: %s", err) + conn := func() (net.Conn, error) { + conn, err := net.Dial("tcp", newMockLineServer(t)) + if err != nil { + t.Fatalf("unable to dial to remote side: %s", err) + } + return conn, err } - client, err := New(conn, clientConfig) + config := &Config{ + Connection: conn, + SSHConfig: clientConfig, + } + + client, err := New(config) if err != nil { t.Fatalf("error connecting to SSH: %s", err) } diff --git a/communicator/ssh/connect.go b/communicator/ssh/connect.go new file mode 100644 index 000000000..cfd55b348 --- /dev/null +++ b/communicator/ssh/connect.go @@ -0,0 +1,16 @@ +package ssh + +import ( + "log" + "net" +) + +// ConnectFunc is a convenience method for returning a function +// that just uses net.Dial to communicate with the remote end that +// is suitable for use with the SSH communicator configuration. +func ConnectFunc(network, addr string) func() (net.Conn, error) { + return func() (net.Conn, error) { + log.Printf("Opening conn for SSH to %s %s", network, addr) + return net.Dial(network, addr) + } +}