Use SSH agent when enabled for bastion step
This commit is contained in:
parent
16e1e488d4
commit
b598baa5e3
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"github.com/mitchellh/packer/communicator/ssh"
|
||||
"github.com/mitchellh/packer/packer"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
// StepConnectSSH is a step that only connects to SSH.
|
||||
|
@ -94,6 +96,7 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
|
|||
|
||||
conf, err := sshBastionConfig(s.Config)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Error calling sshBastionConfig: %v", err)
|
||||
return nil, fmt.Errorf("Error configuring bastion: %s", err)
|
||||
}
|
||||
bConf = conf
|
||||
|
@ -196,7 +199,15 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
|
|||
}
|
||||
|
||||
func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
||||
auth := make([]gossh.AuthMethod, 0, 2)
|
||||
var auth []gossh.AuthMethod
|
||||
|
||||
if !config.SSHDisableAgent {
|
||||
log.Printf("[INFO] SSH agent forwarding enabled.")
|
||||
if sshAgent := sshAgent(); sshAgent != nil {
|
||||
auth = append(auth, sshAgent)
|
||||
}
|
||||
}
|
||||
|
||||
if config.SSHBastionPassword != "" {
|
||||
auth = append(auth,
|
||||
gossh.Password(config.SSHBastionPassword),
|
||||
|
@ -218,3 +229,21 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
|||
Auth: auth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sshAgent() gossh.AuthMethod {
|
||||
socket := os.Getenv("SSH_AUTH_SOCK")
|
||||
if socket == "" {
|
||||
log.Println("[DEBUG] Error fetching SSH_AUTH_SOCK.")
|
||||
return nil
|
||||
}
|
||||
|
||||
agentConn, err := net.Dial("unix", socket)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] net.Dial: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Println("[INFO] Using SSH Agent.")
|
||||
sshAgent := agent.NewClient(agentConn)
|
||||
return gossh.PublicKeysCallback(sshAgent.Signers)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
package communicator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
// startAgent executes ssh-agent, and returns a Agent interface to it.
|
||||
func startAgent(t *testing.T) (agent.Agent, func()) {
|
||||
if testing.Short() {
|
||||
// ssh-agent is not always available, and the key
|
||||
// types supported vary by platform.
|
||||
t.Skip("skipping test due to -short")
|
||||
}
|
||||
|
||||
bin, err := exec.LookPath("ssh-agent")
|
||||
if err != nil {
|
||||
t.Skip("could not find ssh-agent")
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, "-s")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
t.Fatalf("cmd.Output: %v", err)
|
||||
}
|
||||
|
||||
/* Output looks like:
|
||||
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
|
||||
SSH_AGENT_PID=15542; export SSH_AGENT_PID;
|
||||
echo Agent pid 15542;
|
||||
*/
|
||||
fields := bytes.Split(out, []byte(";"))
|
||||
line := bytes.SplitN(fields[0], []byte("="), 2)
|
||||
line[0] = bytes.TrimLeft(line[0], "\n")
|
||||
if string(line[0]) != "SSH_AUTH_SOCK" {
|
||||
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
|
||||
}
|
||||
socket := string(line[1])
|
||||
t.Logf("Socket value: %v", socket)
|
||||
|
||||
origSocket := os.Getenv("SSH_AUTH_SOCK")
|
||||
if err := os.Setenv("SSH_AUTH_SOCK", socket); err != nil {
|
||||
t.Fatalf("could not set SSH_AUTH_SOCK environment variable: %v", err)
|
||||
}
|
||||
|
||||
line = bytes.SplitN(fields[2], []byte("="), 2)
|
||||
line[0] = bytes.TrimLeft(line[0], "\n")
|
||||
if string(line[0]) != "SSH_AGENT_PID" {
|
||||
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
|
||||
}
|
||||
pidStr := line[1]
|
||||
t.Logf("Agent PID: %v", string(pidStr))
|
||||
pid, err := strconv.Atoi(string(pidStr))
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q): %v", pidStr, err)
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", string(socket))
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial: %v", err)
|
||||
}
|
||||
|
||||
return agent.NewClient(conn), func() {
|
||||
proc, _ := os.FindProcess(pid)
|
||||
if proc != nil {
|
||||
proc.Kill()
|
||||
}
|
||||
|
||||
os.Setenv("SSH_AUTH_SOCK", origSocket)
|
||||
conn.Close()
|
||||
os.RemoveAll(filepath.Dir(socket))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgent(t *testing.T) {
|
||||
_, cleanup := startAgent(t)
|
||||
defer cleanup()
|
||||
|
||||
if auth := sshAgent(); auth == nil {
|
||||
t.Error("Want `ssh.AuthMethod`, got `nil`")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHBastionConfig(t *testing.T) {
|
||||
pemPath := TestPEM(t)
|
||||
tests := []struct {
|
||||
in *Config
|
||||
errStr string
|
||||
want int
|
||||
fn func() func()
|
||||
}{
|
||||
{
|
||||
in: &Config{SSHDisableAgent: true},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
in: &Config{SSHDisableAgent: false},
|
||||
want: 0,
|
||||
fn: func() func() {
|
||||
_, cleanup := startAgent(t)
|
||||
os.Unsetenv("SSH_AUTH_SOCK")
|
||||
return cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
in: &Config{
|
||||
SSHDisableAgent: false,
|
||||
SSHBastionPassword: "foobar",
|
||||
SSHBastionPrivateKey: pemPath,
|
||||
},
|
||||
want: 4,
|
||||
fn: func() func() {
|
||||
_, cleanup := startAgent(t)
|
||||
return cleanup
|
||||
},
|
||||
},
|
||||
{
|
||||
in: &Config{
|
||||
SSHBastionPrivateKey: pemPath,
|
||||
},
|
||||
want: 0,
|
||||
errStr: "Failed to read key '" + pemPath + "': no key found",
|
||||
fn: func() func() {
|
||||
os.Truncate(pemPath, 0)
|
||||
return func() {
|
||||
if err := os.Remove(pemPath); err != nil {
|
||||
t.Fatalf("os.Remove: %v", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range tests {
|
||||
func() {
|
||||
if c.fn != nil {
|
||||
defered := c.fn()
|
||||
defer defered()
|
||||
}
|
||||
bConf, err := sshBastionConfig(c.in)
|
||||
if err != nil {
|
||||
if err.Error() != c.errStr {
|
||||
t.Errorf("want error %v, got %q", c.errStr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(bConf.Auth) != c.want {
|
||||
t.Errorf("want %v ssh.AuthMethod, got %v ssh.AuthMethod", c.want, len(bConf.Auth))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue