From b598baa5e30513144853a7700d27a178fe8b6ce5 Mon Sep 17 00:00:00 2001 From: David Campbell Date: Wed, 19 Oct 2016 19:50:35 -0700 Subject: [PATCH] Use SSH agent when enabled for bastion step --- helper/communicator/step_connect_ssh.go | 31 +++- helper/communicator/step_connect_ssh_test.go | 160 +++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 helper/communicator/step_connect_ssh_test.go diff --git a/helper/communicator/step_connect_ssh.go b/helper/communicator/step_connect_ssh.go index 71a6d1a39..6aa0fda38 100644 --- a/helper/communicator/step_connect_ssh.go +++ b/helper/communicator/step_connect_ssh.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net" + "os" "strings" "time" @@ -13,6 +14,7 @@ import ( "github.com/mitchellh/packer/communicator/ssh" "github.com/mitchellh/packer/packer" gossh "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) // StepConnectSSH is a step that only connects to SSH. @@ -94,6 +96,7 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru conf, err := sshBastionConfig(s.Config) if err != nil { + log.Printf("[ERROR] Error calling sshBastionConfig: %v", err) return nil, fmt.Errorf("Error configuring bastion: %s", err) } bConf = conf @@ -196,7 +199,15 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru } func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) { - auth := make([]gossh.AuthMethod, 0, 2) + var auth []gossh.AuthMethod + + if !config.SSHDisableAgent { + log.Printf("[INFO] SSH agent forwarding enabled.") + if sshAgent := sshAgent(); sshAgent != nil { + auth = append(auth, sshAgent) + } + } + if config.SSHBastionPassword != "" { auth = append(auth, gossh.Password(config.SSHBastionPassword), @@ -218,3 +229,21 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) { Auth: auth, }, nil } + +func sshAgent() gossh.AuthMethod { + socket := os.Getenv("SSH_AUTH_SOCK") + if socket == "" { + log.Println("[DEBUG] Error fetching SSH_AUTH_SOCK.") + return nil + } + + agentConn, err := net.Dial("unix", socket) + if err != nil { + log.Printf("[WARN] net.Dial: %v", err) + return nil + } + + log.Println("[INFO] Using SSH Agent.") + sshAgent := agent.NewClient(agentConn) + return gossh.PublicKeysCallback(sshAgent.Signers) +} diff --git a/helper/communicator/step_connect_ssh_test.go b/helper/communicator/step_connect_ssh_test.go new file mode 100644 index 000000000..89a7eb057 --- /dev/null +++ b/helper/communicator/step_connect_ssh_test.go @@ -0,0 +1,160 @@ +package communicator + +import ( + "bytes" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "testing" + + "golang.org/x/crypto/ssh/agent" +) + +// startAgent executes ssh-agent, and returns a Agent interface to it. +func startAgent(t *testing.T) (agent.Agent, func()) { + if testing.Short() { + // ssh-agent is not always available, and the key + // types supported vary by platform. + t.Skip("skipping test due to -short") + } + + bin, err := exec.LookPath("ssh-agent") + if err != nil { + t.Skip("could not find ssh-agent") + } + + cmd := exec.Command(bin, "-s") + out, err := cmd.Output() + if err != nil { + t.Fatalf("cmd.Output: %v", err) + } + + /* Output looks like: + SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; + SSH_AGENT_PID=15542; export SSH_AGENT_PID; + echo Agent pid 15542; + */ + fields := bytes.Split(out, []byte(";")) + line := bytes.SplitN(fields[0], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AUTH_SOCK" { + t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) + } + socket := string(line[1]) + t.Logf("Socket value: %v", socket) + + origSocket := os.Getenv("SSH_AUTH_SOCK") + if err := os.Setenv("SSH_AUTH_SOCK", socket); err != nil { + t.Fatalf("could not set SSH_AUTH_SOCK environment variable: %v", err) + } + + line = bytes.SplitN(fields[2], []byte("="), 2) + line[0] = bytes.TrimLeft(line[0], "\n") + if string(line[0]) != "SSH_AGENT_PID" { + t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) + } + pidStr := line[1] + t.Logf("Agent PID: %v", string(pidStr)) + pid, err := strconv.Atoi(string(pidStr)) + if err != nil { + t.Fatalf("Atoi(%q): %v", pidStr, err) + } + + conn, err := net.Dial("unix", string(socket)) + if err != nil { + t.Fatalf("net.Dial: %v", err) + } + + return agent.NewClient(conn), func() { + proc, _ := os.FindProcess(pid) + if proc != nil { + proc.Kill() + } + + os.Setenv("SSH_AUTH_SOCK", origSocket) + conn.Close() + os.RemoveAll(filepath.Dir(socket)) + } +} + +func TestSSHAgent(t *testing.T) { + _, cleanup := startAgent(t) + defer cleanup() + + if auth := sshAgent(); auth == nil { + t.Error("Want `ssh.AuthMethod`, got `nil`") + } +} + +func TestSSHBastionConfig(t *testing.T) { + pemPath := TestPEM(t) + tests := []struct { + in *Config + errStr string + want int + fn func() func() + }{ + { + in: &Config{SSHDisableAgent: true}, + want: 0, + }, + { + in: &Config{SSHDisableAgent: false}, + want: 0, + fn: func() func() { + _, cleanup := startAgent(t) + os.Unsetenv("SSH_AUTH_SOCK") + return cleanup + }, + }, + { + in: &Config{ + SSHDisableAgent: false, + SSHBastionPassword: "foobar", + SSHBastionPrivateKey: pemPath, + }, + want: 4, + fn: func() func() { + _, cleanup := startAgent(t) + return cleanup + }, + }, + { + in: &Config{ + SSHBastionPrivateKey: pemPath, + }, + want: 0, + errStr: "Failed to read key '" + pemPath + "': no key found", + fn: func() func() { + os.Truncate(pemPath, 0) + return func() { + if err := os.Remove(pemPath); err != nil { + t.Fatalf("os.Remove: %v", err) + } + } + }, + }, + } + + for _, c := range tests { + func() { + if c.fn != nil { + defered := c.fn() + defer defered() + } + bConf, err := sshBastionConfig(c.in) + if err != nil { + if err.Error() != c.errStr { + t.Errorf("want error %v, got %q", c.errStr, err) + } + return + } + + if len(bConf.Auth) != c.want { + t.Errorf("want %v ssh.AuthMethod, got %v ssh.AuthMethod", c.want, len(bConf.Auth)) + } + }() + } +}