Use SSH agent when enabled for bastion step
This commit is contained in:
parent
16e1e488d4
commit
b598baa5e3
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -13,6 +14,7 @@ import (
|
||||||
"github.com/mitchellh/packer/communicator/ssh"
|
"github.com/mitchellh/packer/communicator/ssh"
|
||||||
"github.com/mitchellh/packer/packer"
|
"github.com/mitchellh/packer/packer"
|
||||||
gossh "golang.org/x/crypto/ssh"
|
gossh "golang.org/x/crypto/ssh"
|
||||||
|
"golang.org/x/crypto/ssh/agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StepConnectSSH is a step that only connects to SSH.
|
// StepConnectSSH is a step that only connects to SSH.
|
||||||
|
@ -94,6 +96,7 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
|
||||||
|
|
||||||
conf, err := sshBastionConfig(s.Config)
|
conf, err := sshBastionConfig(s.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("[ERROR] Error calling sshBastionConfig: %v", err)
|
||||||
return nil, fmt.Errorf("Error configuring bastion: %s", err)
|
return nil, fmt.Errorf("Error configuring bastion: %s", err)
|
||||||
}
|
}
|
||||||
bConf = conf
|
bConf = conf
|
||||||
|
@ -196,7 +199,15 @@ func (s *StepConnectSSH) waitForSSH(state multistep.StateBag, cancel <-chan stru
|
||||||
}
|
}
|
||||||
|
|
||||||
func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
||||||
auth := make([]gossh.AuthMethod, 0, 2)
|
var auth []gossh.AuthMethod
|
||||||
|
|
||||||
|
if !config.SSHDisableAgent {
|
||||||
|
log.Printf("[INFO] SSH agent forwarding enabled.")
|
||||||
|
if sshAgent := sshAgent(); sshAgent != nil {
|
||||||
|
auth = append(auth, sshAgent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if config.SSHBastionPassword != "" {
|
if config.SSHBastionPassword != "" {
|
||||||
auth = append(auth,
|
auth = append(auth,
|
||||||
gossh.Password(config.SSHBastionPassword),
|
gossh.Password(config.SSHBastionPassword),
|
||||||
|
@ -218,3 +229,21 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sshAgent() gossh.AuthMethod {
|
||||||
|
socket := os.Getenv("SSH_AUTH_SOCK")
|
||||||
|
if socket == "" {
|
||||||
|
log.Println("[DEBUG] Error fetching SSH_AUTH_SOCK.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
agentConn, err := net.Dial("unix", socket)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] net.Dial: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("[INFO] Using SSH Agent.")
|
||||||
|
sshAgent := agent.NewClient(agentConn)
|
||||||
|
return gossh.PublicKeysCallback(sshAgent.Signers)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
package communicator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh/agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startAgent executes ssh-agent, and returns a Agent interface to it.
|
||||||
|
func startAgent(t *testing.T) (agent.Agent, func()) {
|
||||||
|
if testing.Short() {
|
||||||
|
// ssh-agent is not always available, and the key
|
||||||
|
// types supported vary by platform.
|
||||||
|
t.Skip("skipping test due to -short")
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := exec.LookPath("ssh-agent")
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("could not find ssh-agent")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(bin, "-s")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cmd.Output: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Output looks like:
|
||||||
|
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
|
||||||
|
SSH_AGENT_PID=15542; export SSH_AGENT_PID;
|
||||||
|
echo Agent pid 15542;
|
||||||
|
*/
|
||||||
|
fields := bytes.Split(out, []byte(";"))
|
||||||
|
line := bytes.SplitN(fields[0], []byte("="), 2)
|
||||||
|
line[0] = bytes.TrimLeft(line[0], "\n")
|
||||||
|
if string(line[0]) != "SSH_AUTH_SOCK" {
|
||||||
|
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
|
||||||
|
}
|
||||||
|
socket := string(line[1])
|
||||||
|
t.Logf("Socket value: %v", socket)
|
||||||
|
|
||||||
|
origSocket := os.Getenv("SSH_AUTH_SOCK")
|
||||||
|
if err := os.Setenv("SSH_AUTH_SOCK", socket); err != nil {
|
||||||
|
t.Fatalf("could not set SSH_AUTH_SOCK environment variable: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.SplitN(fields[2], []byte("="), 2)
|
||||||
|
line[0] = bytes.TrimLeft(line[0], "\n")
|
||||||
|
if string(line[0]) != "SSH_AGENT_PID" {
|
||||||
|
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
|
||||||
|
}
|
||||||
|
pidStr := line[1]
|
||||||
|
t.Logf("Agent PID: %v", string(pidStr))
|
||||||
|
pid, err := strconv.Atoi(string(pidStr))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Atoi(%q): %v", pidStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial("unix", string(socket))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("net.Dial: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent.NewClient(conn), func() {
|
||||||
|
proc, _ := os.FindProcess(pid)
|
||||||
|
if proc != nil {
|
||||||
|
proc.Kill()
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Setenv("SSH_AUTH_SOCK", origSocket)
|
||||||
|
conn.Close()
|
||||||
|
os.RemoveAll(filepath.Dir(socket))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHAgent(t *testing.T) {
|
||||||
|
_, cleanup := startAgent(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if auth := sshAgent(); auth == nil {
|
||||||
|
t.Error("Want `ssh.AuthMethod`, got `nil`")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHBastionConfig(t *testing.T) {
|
||||||
|
pemPath := TestPEM(t)
|
||||||
|
tests := []struct {
|
||||||
|
in *Config
|
||||||
|
errStr string
|
||||||
|
want int
|
||||||
|
fn func() func()
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
in: &Config{SSHDisableAgent: true},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: &Config{SSHDisableAgent: false},
|
||||||
|
want: 0,
|
||||||
|
fn: func() func() {
|
||||||
|
_, cleanup := startAgent(t)
|
||||||
|
os.Unsetenv("SSH_AUTH_SOCK")
|
||||||
|
return cleanup
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: &Config{
|
||||||
|
SSHDisableAgent: false,
|
||||||
|
SSHBastionPassword: "foobar",
|
||||||
|
SSHBastionPrivateKey: pemPath,
|
||||||
|
},
|
||||||
|
want: 4,
|
||||||
|
fn: func() func() {
|
||||||
|
_, cleanup := startAgent(t)
|
||||||
|
return cleanup
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: &Config{
|
||||||
|
SSHBastionPrivateKey: pemPath,
|
||||||
|
},
|
||||||
|
want: 0,
|
||||||
|
errStr: "Failed to read key '" + pemPath + "': no key found",
|
||||||
|
fn: func() func() {
|
||||||
|
os.Truncate(pemPath, 0)
|
||||||
|
return func() {
|
||||||
|
if err := os.Remove(pemPath); err != nil {
|
||||||
|
t.Fatalf("os.Remove: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range tests {
|
||||||
|
func() {
|
||||||
|
if c.fn != nil {
|
||||||
|
defered := c.fn()
|
||||||
|
defer defered()
|
||||||
|
}
|
||||||
|
bConf, err := sshBastionConfig(c.in)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() != c.errStr {
|
||||||
|
t.Errorf("want error %v, got %q", c.errStr, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(bConf.Auth) != c.want {
|
||||||
|
t.Errorf("want %v ssh.AuthMethod, got %v ssh.AuthMethod", c.want, len(bConf.Auth))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue