packer-cn/common/adapter/adapter.go
Adrien Delorme f555e7a9f2 allow a provisioner to timeout
* I had to contextualise Communicator.Start and RemoteCmd.StartWithUi
NOTE: Communicator.Start starts a RemoteCmd but RemoteCmd.StartWithUi will run the cmd and wait for a return, so I renamed StartWithUi to RunWithUi so that the intent is clearer.
Ideally in the future RunWithUi will be named back to StartWithUi and the exit status or wait funcs of the command will allow to wait for a return. If you do so please read carrefully https://golang.org/pkg/os/exec/#Cmd.Stdout to avoid a deadlock
* cmd.ExitStatus to cmd.ExitStatus() is now blocking to avoid race conditions
* also had to simplify StartWithUi
2019-04-08 20:09:21 +02:00

339 lines
7.3 KiB
Go

package adapter
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"strings"
"github.com/google/shlex"
"github.com/hashicorp/packer/packer"
"golang.org/x/crypto/ssh"
)
// An adapter satisfies SSH requests (from an Ansible client) by delegating SSH
// exec and subsystem commands to a packer.Communicator.
type Adapter struct {
done <-chan struct{}
l net.Listener
config *ssh.ServerConfig
sftpCmd string
ui packer.Ui
comm packer.Communicator
}
func NewAdapter(done <-chan struct{}, l net.Listener, config *ssh.ServerConfig, sftpCmd string, ui packer.Ui, comm packer.Communicator) *Adapter {
return &Adapter{
done: done,
l: l,
config: config,
sftpCmd: sftpCmd,
ui: ui,
comm: comm,
}
}
func (c *Adapter) Serve() {
log.Printf("SSH proxy: serving on %s", c.l.Addr())
for {
// Accept will return if either the underlying connection is closed or if a connection is made.
// after returning, check to see if c.done can be received. If so, then Accept() returned because
// the connection has been closed.
conn, err := c.l.Accept()
select {
case <-c.done:
return
default:
if err != nil {
c.ui.Error(fmt.Sprintf("listen.Accept failed: %v", err))
continue
}
go func(conn net.Conn) {
if err := c.Handle(conn, c.ui); err != nil {
c.ui.Error(err.Error())
}
}(conn)
}
}
}
func (c *Adapter) Handle(conn net.Conn, ui packer.Ui) error {
log.Print("SSH proxy: accepted connection")
_, chans, reqs, err := ssh.NewServerConn(conn, c.config)
if err != nil {
return errors.New("failed to handshake")
}
// discard all global requests
go ssh.DiscardRequests(reqs)
// Service the incoming NewChannels
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
go func(ch ssh.NewChannel) {
if err := c.handleSession(ch); err != nil {
c.ui.Error(err.Error())
}
}(newChannel)
}
return nil
}
func (c *Adapter) handleSession(newChannel ssh.NewChannel) error {
channel, requests, err := newChannel.Accept()
if err != nil {
return err
}
defer channel.Close()
done := make(chan struct{})
// Sessions have requests such as "pty-req", "shell", "env", and "exec".
// see RFC 4254, section 6
go func(in <-chan *ssh.Request) {
env := make([]envRequestPayload, 4)
for req := range in {
switch req.Type {
case "pty-req":
log.Println("ansible provisioner pty-req request")
// accept pty-req requests, but don't actually do anything. Necessary for OpenSSH and sudo.
req.Reply(true, nil)
case "env":
req, err := newEnvRequest(req)
if err != nil {
c.ui.Error(err.Error())
req.Reply(false, nil)
continue
}
env = append(env, req.Payload)
log.Printf("new env request: %s", req.Payload)
req.Reply(true, nil)
case "exec":
req, err := newExecRequest(req)
if err != nil {
c.ui.Error(err.Error())
req.Reply(false, nil)
close(done)
continue
}
log.Printf("new exec request: %s", req.Payload)
if len(req.Payload) == 0 {
req.Reply(false, nil)
close(done)
return
}
go func(channel ssh.Channel) {
exit := c.exec(string(req.Payload), channel, channel, channel.Stderr())
exitStatus := make([]byte, 4)
binary.BigEndian.PutUint32(exitStatus, uint32(exit))
channel.SendRequest("exit-status", false, exitStatus)
close(done)
}(channel)
req.Reply(true, nil)
case "subsystem":
req, err := newSubsystemRequest(req)
if err != nil {
c.ui.Error(err.Error())
req.Reply(false, nil)
continue
}
log.Printf("new subsystem request: %s", req.Payload)
switch req.Payload {
case "sftp":
sftpCmd := c.sftpCmd
if len(sftpCmd) == 0 {
sftpCmd = "/usr/lib/sftp-server -e"
}
log.Print("starting sftp subsystem")
go func() {
_ = c.remoteExec(sftpCmd, channel, channel, channel.Stderr())
close(done)
}()
req.Reply(true, nil)
default:
c.ui.Error(fmt.Sprintf("unsupported subsystem requested: %s", req.Payload))
req.Reply(false, nil)
}
default:
log.Printf("rejecting %s request", req.Type)
req.Reply(false, nil)
}
}
}(requests)
<-done
return nil
}
func (c *Adapter) Shutdown() {
c.l.Close()
}
func (c *Adapter) exec(command string, in io.Reader, out io.Writer, err io.Writer) int {
var exitStatus int
switch {
case strings.HasPrefix(command, "scp ") && serveSCP(command[4:]):
err := c.scpExec(command[4:], in, out)
if err != nil {
log.Println(err)
exitStatus = 1
}
default:
exitStatus = c.remoteExec(command, in, out, err)
}
return exitStatus
}
func serveSCP(args string) bool {
opts, _ := scpOptions(args)
return bytes.IndexAny(opts, "tf") >= 0
}
func (c *Adapter) scpExec(args string, in io.Reader, out io.Writer) error {
opts, rest := scpOptions(args)
// remove the quoting that ansible added to rest for shell safety.
shargs, err := shlex.Split(rest)
if err != nil {
return err
}
rest = strings.Join(shargs, "")
if i := bytes.IndexByte(opts, 't'); i >= 0 {
return scpUploadSession(opts, rest, in, out, c.comm)
}
if i := bytes.IndexByte(opts, 'f'); i >= 0 {
return scpDownloadSession(opts, rest, in, out, c.comm)
}
return errors.New("no scp mode specified")
}
func (c *Adapter) remoteExec(command string, in io.Reader, out io.Writer, err io.Writer) int {
cmd := &packer.RemoteCmd{
Stdin: in,
Stdout: out,
Stderr: err,
Command: command,
}
ctx := context.TODO()
if err := c.comm.Start(ctx, cmd); err != nil {
c.ui.Error(err.Error())
}
cmd.Wait()
return cmd.ExitStatus()
}
type envRequest struct {
*ssh.Request
Payload envRequestPayload
}
type envRequestPayload struct {
Name string
Value string
}
func (p envRequestPayload) String() string {
return fmt.Sprintf("%s=%s", p.Name, p.Value)
}
func newEnvRequest(raw *ssh.Request) (*envRequest, error) {
r := new(envRequest)
r.Request = raw
if err := ssh.Unmarshal(raw.Payload, &r.Payload); err != nil {
return nil, err
}
return r, nil
}
func sshString(buf io.Reader) (string, error) {
var size uint32
err := binary.Read(buf, binary.BigEndian, &size)
if err != nil {
return "", err
}
b := make([]byte, size)
err = binary.Read(buf, binary.BigEndian, b)
if err != nil {
return "", err
}
return string(b), nil
}
type execRequest struct {
*ssh.Request
Payload execRequestPayload
}
type execRequestPayload string
func (p execRequestPayload) String() string {
return string(p)
}
func newExecRequest(raw *ssh.Request) (*execRequest, error) {
r := new(execRequest)
r.Request = raw
buf := bytes.NewReader(r.Request.Payload)
var err error
var payload string
if payload, err = sshString(buf); err != nil {
return nil, err
}
r.Payload = execRequestPayload(payload)
return r, nil
}
type subsystemRequest struct {
*ssh.Request
Payload subsystemRequestPayload
}
type subsystemRequestPayload string
func (p subsystemRequestPayload) String() string {
return string(p)
}
func newSubsystemRequest(raw *ssh.Request) (*subsystemRequest, error) {
r := new(subsystemRequest)
r.Request = raw
buf := bytes.NewReader(r.Request.Payload)
var err error
var payload string
if payload, err = sshString(buf); err != nil {
return nil, err
}
r.Payload = subsystemRequestPayload(payload)
return r, nil
}