Use SSH agent when enabled for bastion step

This commit is contained in:
David Campbell 2016-10-19 19:50:35 -07:00 committed by Matthew Hooker
parent 16e1e488d4
commit b598baa5e3
No known key found for this signature in database
GPG Key ID: 7B5F933D9CE8C6A1
2 changed files with 190 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
"log"
"net"
"os"
"strings"
"time"
@ -13,6 +14,7 @@ import (
"github.com/mitchellh/packer/communicator/ssh"
"github.com/mitchellh/packer/packer"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// StepConnectSSH is a step that only connects to SSH.
@ -94,6 +96,7 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
conf, err := sshBastionConfig(s.Config)
if err != nil {
log.Printf("[ERROR] Error calling sshBastionConfig: %v", err)
return nil, fmt.Errorf("Error configuring bastion: %s", err)
}
bConf = conf
@ -196,7 +199,15 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
}
func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
auth := make([]gossh.AuthMethod, 0, 2)
var auth []gossh.AuthMethod
if !config.SSHDisableAgent {
log.Printf("[INFO] SSH agent forwarding enabled.")
if sshAgent := sshAgent(); sshAgent != nil {
auth = append(auth, sshAgent)
}
}
if config.SSHBastionPassword != "" {
auth = append(auth,
gossh.Password(config.SSHBastionPassword),
@ -218,3 +229,21 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
Auth: auth,
}, nil
}
func sshAgent() gossh.AuthMethod {
socket := os.Getenv("SSH_AUTH_SOCK")
if socket == "" {
log.Println("[DEBUG] Error fetching SSH_AUTH_SOCK.")
return nil
}
agentConn, err := net.Dial("unix", socket)
if err != nil {
log.Printf("[WARN] net.Dial: %v", err)
return nil
}
log.Println("[INFO] Using SSH Agent.")
sshAgent := agent.NewClient(agentConn)
return gossh.PublicKeysCallback(sshAgent.Signers)
}

View File

@ -0,0 +1,160 @@
package communicator
import (
"bytes"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"testing"
"golang.org/x/crypto/ssh/agent"
)
// startAgent executes ssh-agent, and returns a Agent interface to it.
func startAgent(t *testing.T) (agent.Agent, func()) {
if testing.Short() {
// ssh-agent is not always available, and the key
// types supported vary by platform.
t.Skip("skipping test due to -short")
}
bin, err := exec.LookPath("ssh-agent")
if err != nil {
t.Skip("could not find ssh-agent")
}
cmd := exec.Command(bin, "-s")
out, err := cmd.Output()
if err != nil {
t.Fatalf("cmd.Output: %v", err)
}
/* Output looks like:
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
SSH_AGENT_PID=15542; export SSH_AGENT_PID;
echo Agent pid 15542;
*/
fields := bytes.Split(out, []byte(";"))
line := bytes.SplitN(fields[0], []byte("="), 2)
line[0] = bytes.TrimLeft(line[0], "\n")
if string(line[0]) != "SSH_AUTH_SOCK" {
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
}
socket := string(line[1])
t.Logf("Socket value: %v", socket)
origSocket := os.Getenv("SSH_AUTH_SOCK")
if err := os.Setenv("SSH_AUTH_SOCK", socket); err != nil {
t.Fatalf("could not set SSH_AUTH_SOCK environment variable: %v", err)
}
line = bytes.SplitN(fields[2], []byte("="), 2)
line[0] = bytes.TrimLeft(line[0], "\n")
if string(line[0]) != "SSH_AGENT_PID" {
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
}
pidStr := line[1]
t.Logf("Agent PID: %v", string(pidStr))
pid, err := strconv.Atoi(string(pidStr))
if err != nil {
t.Fatalf("Atoi(%q): %v", pidStr, err)
}
conn, err := net.Dial("unix", string(socket))
if err != nil {
t.Fatalf("net.Dial: %v", err)
}
return agent.NewClient(conn), func() {
proc, _ := os.FindProcess(pid)
if proc != nil {
proc.Kill()
}
os.Setenv("SSH_AUTH_SOCK", origSocket)
conn.Close()
os.RemoveAll(filepath.Dir(socket))
}
}
func TestSSHAgent(t *testing.T) {
_, cleanup := startAgent(t)
defer cleanup()
if auth := sshAgent(); auth == nil {
t.Error("Want `ssh.AuthMethod`, got `nil`")
}
}
func TestSSHBastionConfig(t *testing.T) {
pemPath := TestPEM(t)
tests := []struct {
in *Config
errStr string
want int
fn func() func()
}{
{
in: &Config{SSHDisableAgent: true},
want: 0,
},
{
in: &Config{SSHDisableAgent: false},
want: 0,
fn: func() func() {
_, cleanup := startAgent(t)
os.Unsetenv("SSH_AUTH_SOCK")
return cleanup
},
},
{
in: &Config{
SSHDisableAgent: false,
SSHBastionPassword: "foobar",
SSHBastionPrivateKey: pemPath,
},
want: 4,
fn: func() func() {
_, cleanup := startAgent(t)
return cleanup
},
},
{
in: &Config{
SSHBastionPrivateKey: pemPath,
},
want: 0,
errStr: "Failed to read key '" + pemPath + "': no key found",
fn: func() func() {
os.Truncate(pemPath, 0)
return func() {
if err := os.Remove(pemPath); err != nil {
t.Fatalf("os.Remove: %v", err)
}
}
},
},
}
for _, c := range tests {
func() {
if c.fn != nil {
defered := c.fn()
defer defered()
}
bConf, err := sshBastionConfig(c.in)
if err != nil {
if err.Error() != c.errStr {
t.Errorf("want error %v, got %q", c.errStr, err)
}
return
}
if len(bConf.Auth) != c.want {
t.Errorf("want %v ssh.AuthMethod, got %v ssh.AuthMethod", c.want, len(bConf.Auth))
}
}()
}
}