339 lines
7.3 KiB
Go
339 lines
7.3 KiB
Go
package adapter
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strings"
|
|
|
|
"github.com/google/shlex"
|
|
packersdk "github.com/hashicorp/packer/packer-plugin-sdk/packer"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// An adapter satisfies SSH requests (from an Ansible client) by delegating SSH
|
|
// exec and subsystem commands to a packersdk.Communicator.
|
|
type Adapter struct {
|
|
done <-chan struct{}
|
|
l net.Listener
|
|
config *ssh.ServerConfig
|
|
sftpCmd string
|
|
ui packersdk.Ui
|
|
comm packersdk.Communicator
|
|
}
|
|
|
|
func NewAdapter(done <-chan struct{}, l net.Listener, config *ssh.ServerConfig, sftpCmd string, ui packersdk.Ui, comm packersdk.Communicator) *Adapter {
|
|
return &Adapter{
|
|
done: done,
|
|
l: l,
|
|
config: config,
|
|
sftpCmd: sftpCmd,
|
|
ui: ui,
|
|
comm: comm,
|
|
}
|
|
}
|
|
|
|
func (c *Adapter) Serve() {
|
|
log.Printf("SSH proxy: serving on %s", c.l.Addr())
|
|
|
|
for {
|
|
// Accept will return if either the underlying connection is closed or if a connection is made.
|
|
// after returning, check to see if c.done can be received. If so, then Accept() returned because
|
|
// the connection has been closed.
|
|
conn, err := c.l.Accept()
|
|
select {
|
|
case <-c.done:
|
|
return
|
|
default:
|
|
if err != nil {
|
|
c.ui.Error(fmt.Sprintf("listen.Accept failed: %v", err))
|
|
continue
|
|
}
|
|
go func(conn net.Conn) {
|
|
if err := c.Handle(conn, c.ui); err != nil {
|
|
c.ui.Error(err.Error())
|
|
}
|
|
}(conn)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Adapter) Handle(conn net.Conn, ui packersdk.Ui) error {
|
|
log.Print("SSH proxy: accepted connection")
|
|
_, chans, reqs, err := ssh.NewServerConn(conn, c.config)
|
|
if err != nil {
|
|
return errors.New("failed to handshake")
|
|
}
|
|
|
|
// discard all global requests
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
// Service the incoming NewChannels
|
|
for newChannel := range chans {
|
|
if newChannel.ChannelType() != "session" {
|
|
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
continue
|
|
}
|
|
|
|
go func(ch ssh.NewChannel) {
|
|
if err := c.handleSession(ch); err != nil {
|
|
c.ui.Error(err.Error())
|
|
}
|
|
}(newChannel)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Adapter) handleSession(newChannel ssh.NewChannel) error {
|
|
channel, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer channel.Close()
|
|
|
|
done := make(chan struct{})
|
|
|
|
// Sessions have requests such as "pty-req", "shell", "env", and "exec".
|
|
// see RFC 4254, section 6
|
|
go func(in <-chan *ssh.Request) {
|
|
env := make([]envRequestPayload, 4)
|
|
for req := range in {
|
|
switch req.Type {
|
|
case "pty-req":
|
|
log.Println("ansible provisioner pty-req request")
|
|
// accept pty-req requests, but don't actually do anything. Necessary for OpenSSH and sudo.
|
|
req.Reply(true, nil)
|
|
|
|
case "env":
|
|
req, err := newEnvRequest(req)
|
|
if err != nil {
|
|
c.ui.Error(err.Error())
|
|
req.Reply(false, nil)
|
|
continue
|
|
}
|
|
env = append(env, req.Payload)
|
|
log.Printf("new env request: %s", req.Payload)
|
|
req.Reply(true, nil)
|
|
case "exec":
|
|
req, err := newExecRequest(req)
|
|
if err != nil {
|
|
c.ui.Error(err.Error())
|
|
req.Reply(false, nil)
|
|
close(done)
|
|
continue
|
|
}
|
|
|
|
log.Printf("new exec request: %s", req.Payload)
|
|
|
|
if len(req.Payload) == 0 {
|
|
req.Reply(false, nil)
|
|
close(done)
|
|
return
|
|
}
|
|
|
|
go func(channel ssh.Channel) {
|
|
exit := c.exec(string(req.Payload), channel, channel, channel.Stderr())
|
|
|
|
exitStatus := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(exitStatus, uint32(exit))
|
|
channel.SendRequest("exit-status", false, exitStatus)
|
|
close(done)
|
|
}(channel)
|
|
req.Reply(true, nil)
|
|
case "subsystem":
|
|
req, err := newSubsystemRequest(req)
|
|
if err != nil {
|
|
c.ui.Error(err.Error())
|
|
req.Reply(false, nil)
|
|
continue
|
|
}
|
|
|
|
log.Printf("new subsystem request: %s", req.Payload)
|
|
switch req.Payload {
|
|
case "sftp":
|
|
sftpCmd := c.sftpCmd
|
|
if len(sftpCmd) == 0 {
|
|
sftpCmd = "/usr/lib/sftp-server -e"
|
|
}
|
|
|
|
log.Print("starting sftp subsystem")
|
|
go func() {
|
|
_ = c.remoteExec(sftpCmd, channel, channel, channel.Stderr())
|
|
close(done)
|
|
}()
|
|
req.Reply(true, nil)
|
|
default:
|
|
c.ui.Error(fmt.Sprintf("unsupported subsystem requested: %s", req.Payload))
|
|
req.Reply(false, nil)
|
|
}
|
|
default:
|
|
log.Printf("rejecting %s request", req.Type)
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}(requests)
|
|
|
|
<-done
|
|
return nil
|
|
}
|
|
|
|
func (c *Adapter) Shutdown() {
|
|
c.l.Close()
|
|
}
|
|
|
|
func (c *Adapter) exec(command string, in io.Reader, out io.Writer, err io.Writer) int {
|
|
var exitStatus int
|
|
switch {
|
|
case strings.HasPrefix(command, "scp ") && serveSCP(command[4:]):
|
|
err := c.scpExec(command[4:], in, out)
|
|
if err != nil {
|
|
log.Println(err)
|
|
exitStatus = 1
|
|
}
|
|
default:
|
|
exitStatus = c.remoteExec(command, in, out, err)
|
|
}
|
|
return exitStatus
|
|
}
|
|
|
|
func serveSCP(args string) bool {
|
|
opts, _ := scpOptions(args)
|
|
return bytes.IndexAny(opts, "tf") >= 0
|
|
}
|
|
|
|
func (c *Adapter) scpExec(args string, in io.Reader, out io.Writer) error {
|
|
opts, rest := scpOptions(args)
|
|
|
|
// remove the quoting that ansible added to rest for shell safety.
|
|
shargs, err := shlex.Split(rest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rest = strings.Join(shargs, "")
|
|
|
|
if i := bytes.IndexByte(opts, 't'); i >= 0 {
|
|
return scpUploadSession(opts, rest, in, out, c.comm)
|
|
}
|
|
|
|
if i := bytes.IndexByte(opts, 'f'); i >= 0 {
|
|
return scpDownloadSession(opts, rest, in, out, c.comm)
|
|
}
|
|
return errors.New("no scp mode specified")
|
|
}
|
|
|
|
func (c *Adapter) remoteExec(command string, in io.Reader, out io.Writer, err io.Writer) int {
|
|
cmd := &packersdk.RemoteCmd{
|
|
Stdin: in,
|
|
Stdout: out,
|
|
Stderr: err,
|
|
Command: command,
|
|
}
|
|
ctx := context.TODO()
|
|
|
|
if err := c.comm.Start(ctx, cmd); err != nil {
|
|
c.ui.Error(err.Error())
|
|
}
|
|
|
|
cmd.Wait()
|
|
|
|
return cmd.ExitStatus()
|
|
}
|
|
|
|
type envRequest struct {
|
|
*ssh.Request
|
|
Payload envRequestPayload
|
|
}
|
|
|
|
// envRequestPayload is the decoded body of an SSH "env" request: a single
// environment variable.
type envRequestPayload struct {
	// Name is the variable's name.
	Name string
	// Value is the variable's value.
	Value string
}

// String renders the variable as "Name=Value".
func (p envRequestPayload) String() string {
	// Sprint inserts no separators between string operands, so this yields
	// exactly Name, "=", Value.
	return fmt.Sprint(p.Name, "=", p.Value)
}
|
|
|
|
func newEnvRequest(raw *ssh.Request) (*envRequest, error) {
|
|
r := new(envRequest)
|
|
r.Request = raw
|
|
|
|
if err := ssh.Unmarshal(raw.Payload, &r.Payload); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// sshString reads one length-prefixed string from buf: a big-endian uint32
// byte count followed by that many bytes of data.
//
// The data is read incrementally rather than into a buffer pre-sized from the
// declared length, so a short payload that announces a huge length fails
// quickly instead of forcing a multi-gigabyte allocation.
//
// It returns the error from the reader if the length prefix cannot be read.
// If the data ends early it returns io.EOF when no data byte was available and
// io.ErrUnexpectedEOF when only part of it was.
func sshString(buf io.Reader) (string, error) {
	var size uint32
	if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
		return "", err
	}

	var data bytes.Buffer
	n, err := io.CopyN(&data, buf, int64(size))
	if err != nil {
		// CopyN reports every short read as io.EOF; distinguish a partial
		// body so callers still see io.ErrUnexpectedEOF for truncated data.
		if err == io.EOF && n > 0 {
			err = io.ErrUnexpectedEOF
		}
		return "", err
	}

	return data.String(), nil
}
|
|
|
|
type execRequest struct {
|
|
*ssh.Request
|
|
Payload execRequestPayload
|
|
}
|
|
|
|
// execRequestPayload is the command line carried by an SSH "exec" request.
type execRequestPayload string

// String returns the command line as a plain string, which lets the payload
// be used directly with %s formatting.
func (p execRequestPayload) String() string {
	command := string(p)
	return command
}
|
|
|
|
func newExecRequest(raw *ssh.Request) (*execRequest, error) {
|
|
r := new(execRequest)
|
|
r.Request = raw
|
|
buf := bytes.NewReader(r.Request.Payload)
|
|
|
|
var err error
|
|
var payload string
|
|
if payload, err = sshString(buf); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r.Payload = execRequestPayload(payload)
|
|
return r, nil
|
|
}
|
|
|
|
type subsystemRequest struct {
|
|
*ssh.Request
|
|
Payload subsystemRequestPayload
|
|
}
|
|
|
|
// subsystemRequestPayload is the subsystem name carried by an SSH "subsystem"
// request, for example "sftp".
type subsystemRequestPayload string

// String returns the subsystem name as a plain string, which lets the payload
// be used directly with %s formatting.
func (p subsystemRequestPayload) String() string {
	name := string(p)
	return name
}
|
|
|
|
func newSubsystemRequest(raw *ssh.Request) (*subsystemRequest, error) {
|
|
r := new(subsystemRequest)
|
|
r.Request = raw
|
|
buf := bytes.NewReader(r.Request.Payload)
|
|
|
|
var err error
|
|
var payload string
|
|
if payload, err = sshString(buf); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r.Payload = subsystemRequestPayload(payload)
|
|
return r, nil
|
|
}
|