From 70e3f6053950b37358e57f2d7c4e80bfc747dafc Mon Sep 17 00:00:00 2001 From: r_takaishi Date: Fri, 13 Mar 2020 14:01:11 +0900 Subject: [PATCH] write unit test --- communicator/ssh/keyboard_interactive.go | 18 +--- communicator/ssh/keyboard_interactive_test.go | 89 +++++++++++++++++++ helper/communicator/step_connect_ssh.go | 15 +++- 3 files changed, 107 insertions(+), 15 deletions(-) create mode 100644 communicator/ssh/keyboard_interactive_test.go diff --git a/communicator/ssh/keyboard_interactive.go b/communicator/ssh/keyboard_interactive.go index 4da16d733..417ab00ae 100644 --- a/communicator/ssh/keyboard_interactive.go +++ b/communicator/ssh/keyboard_interactive.go @@ -1,14 +1,15 @@ package ssh import ( + "io" "log" - "os" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" ) -func KeyboardInteractive() ssh.KeyboardInteractiveChallenge { +func KeyboardInteractive(c io.ReadWriter) ssh.KeyboardInteractiveChallenge { + t := terminal.NewTerminal(c, "") return func(user, instruction string, questions []string, echos []bool) ([]string, error) { if len(questions) == 0 { return []string{}, nil @@ -21,18 +22,7 @@ func KeyboardInteractive() ssh.KeyboardInteractiveChallenge { } answers := make([]string, len(questions)) for i := range questions { - var fd int - if terminal.IsTerminal(int(os.Stdin.Fd())) { - fd = int(os.Stdin.Fd()) - } else { - tty, err := os.Open("/dev/tty") - if err != nil { - return nil, err - } - defer tty.Close() - fd = int(tty.Fd()) - } - s, err := terminal.ReadPassword(fd) + s, err := t.ReadPassword("") if err != nil { return nil, err } diff --git a/communicator/ssh/keyboard_interactive_test.go b/communicator/ssh/keyboard_interactive_test.go new file mode 100644 index 000000000..fc1bbfd45 --- /dev/null +++ b/communicator/ssh/keyboard_interactive_test.go @@ -0,0 +1,89 @@ +package ssh + +import ( + "io" + "log" + "reflect" + "testing" +) + +type MockTerminal struct { + toSend []byte + bytesPerRead int + received []byte +} + +func (c *MockTerminal) Read(data []byte) (n int, err error) { + n = len(data) + if n == 0 { + return + } + if n > len(c.toSend) { + n = len(c.toSend) + } + if n == 0 { + return 0, io.EOF + } + if c.bytesPerRead > 0 && n > c.bytesPerRead { + n = c.bytesPerRead + } + copy(data, c.toSend[:n]) + c.toSend = c.toSend[n:] + return +} + +func (c *MockTerminal) Write(data []byte) (n int, err error) { + c.received = append(c.received, data...) + return len(data), nil +} + +func TestKeyboardInteractive(t *testing.T) { + type args struct { + user string + instruction string + questions []string + echos []bool + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "questions are none", + args: args{ + questions: []string{}, + }, + want: []string{}, + wantErr: false, + }, + { + name: "input answer interactive", + args: args{ + questions: []string{"this is question"}, + }, + want: []string{"xxxx"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &MockTerminal{ + toSend: []byte("xxxx\r\x1b[A\r"), + bytesPerRead: 1, + } + f := KeyboardInteractive(c) + got, err := f(tt.args.user, tt.args.instruction, tt.args.questions, tt.args.echos) + + if (err != nil) != tt.wantErr { + t.Errorf("KeyboardInteractive error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KeyboardInteractive = %v, want %v", got, tt.want) + } + log.Printf("finish") + }) + } +} diff --git a/helper/communicator/step_connect_ssh.go b/helper/communicator/step_connect_ssh.go index 8eb1c5156..8ad82c23e 100644 --- a/helper/communicator/step_connect_ssh.go +++ b/helper/communicator/step_connect_ssh.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "golang.org/x/crypto/ssh/terminal" + "io" "log" "net" "os" @@ -247,7 +249,18 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) { auth := make([]gossh.AuthMethod, 0, 2) if config.SSHBastionInteractive { - auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive())) + var c io.ReadWriteCloser + if terminal.IsTerminal(int(os.Stdin.Fd())) { + c = os.Stdin + } else { + tty, err := os.Open("/dev/tty") + if err != nil { + return nil, err + } + defer tty.Close() + c = tty + } + auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive(c))) } if config.SSHBastionPassword != "" {