packer/plugin: connect stdin to parent stdin

This commit is contained in:
Mitchell Hashimoto 2013-07-25 21:24:49 -05:00
parent ce5849308b
commit 9289df6d35
3 changed files with 64 additions and 0 deletions

View File

@ -216,6 +216,7 @@ func (c *Client) Start() (address string, err error) {
cmd := c.config.Cmd cmd := c.config.Cmd
cmd.Env = append(cmd.Env, os.Environ()...) cmd.Env = append(cmd.Env, os.Environ()...)
cmd.Env = append(cmd.Env, env...) cmd.Env = append(cmd.Env, env...)
cmd.Stdin = os.Stdin
cmd.Stderr = stderr cmd.Stderr = stderr
cmd.Stdout = stdout cmd.Stdout = stdout

View File

@ -1,6 +1,8 @@
package plugin package plugin
import ( import (
"io/ioutil"
"os"
"testing" "testing"
"time" "time"
) )
@ -47,3 +49,50 @@ func TestClient_Start_Timeout(t *testing.T) {
t.Fatal("err should not be nil") t.Fatal("err should not be nil")
} }
} }
func TestClient_Stdin(t *testing.T) {
// Overwrite stdin for this test with a temporary file
tf, err := ioutil.TempFile("", "packer")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.Remove(tf.Name())
defer tf.Close()
if _, err = tf.WriteString("hello"); err != nil {
t.Fatalf("error: %s", err)
}
if err = tf.Sync(); err != nil {
t.Fatalf("error: %s", err)
}
if _, err = tf.Seek(0, 0); err != nil {
t.Fatalf("error: %s", err)
}
oldStdin := os.Stdin
defer func() { os.Stdin = oldStdin }()
os.Stdin = tf
process := helperProcess("stdin")
c := NewClient(&ClientConfig{Cmd: process})
defer c.Kill()
_, err = c.Start()
if err != nil {
t.Fatalf("error: %s", err)
}
for {
if c.Exited() {
break
}
time.Sleep(50 * time.Millisecond)
}
if !process.ProcessState.Success() {
t.Fatal("process didn't exit cleanly")
}
}

View File

@ -2,6 +2,7 @@ package plugin
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"os/exec" "os/exec"
"testing" "testing"
@ -65,6 +66,19 @@ func TestHelperProcess(*testing.T) {
ServeProvisioner(new(helperProvisioner)) ServeProvisioner(new(helperProvisioner))
case "start-timeout": case "start-timeout":
time.Sleep(1 * time.Minute) time.Sleep(1 * time.Minute)
os.Exit(1)
case "stdin":
fmt.Println(":1234")
data := make([]byte, 5)
if _, err := os.Stdin.Read(data); err != nil {
log.Printf("stdin read error: %s", err)
os.Exit(100)
}
if string(data) == "hello" {
os.Exit(0)
}
os.Exit(1) os.Exit(1)
default: default:
fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd) fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd)