write unit test

This commit is contained in:
r_takaishi 2020-03-13 14:01:11 +09:00
parent f50ff1d270
commit 70e3f60539
3 changed files with 107 additions and 15 deletions

View File

@ -1,14 +1,15 @@
package ssh
import (
"io"
"log"
"os"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
func KeyboardInteractive() ssh.KeyboardInteractiveChallenge {
func KeyboardInteractive(c io.ReadWriter) ssh.KeyboardInteractiveChallenge {
t := terminal.NewTerminal(c, "")
return func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if len(questions) == 0 {
return []string{}, nil
@ -21,18 +22,7 @@ func KeyboardInteractive() ssh.KeyboardInteractiveChallenge {
}
answers := make([]string, len(questions))
for i := range questions {
var fd int
if terminal.IsTerminal(int(os.Stdin.Fd())) {
fd = int(os.Stdin.Fd())
} else {
tty, err := os.Open("/dev/tty")
if err != nil {
return nil, err
}
defer tty.Close()
fd = int(tty.Fd())
}
s, err := terminal.ReadPassword(fd)
s, err := t.ReadPassword("")
if err != nil {
return nil, err
}

View File

@ -0,0 +1,89 @@
package ssh
import (
"io"
"log"
"reflect"
"testing"
)
type MockTerminal struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockTerminal) Read(data []byte) (n int, err error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, io.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockTerminal) Write(data []byte) (n int, err error) {
c.received = append(c.received, data...)
return len(data), nil
}
func TestKeyboardInteractive(t *testing.T) {
type args struct {
user string
instruction string
questions []string
echos []bool
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "questions are none",
args: args{
questions: []string{},
},
want: []string{},
wantErr: false,
},
{
name: "input answer interactive",
args: args{
questions: []string{"this is question"},
},
want: []string{"xxxx"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &MockTerminal{
toSend: []byte("xxxx\r\x1b[A\r"),
bytesPerRead: 1,
}
f := KeyboardInteractive(c)
got, err := f(tt.args.user, tt.args.instruction, tt.args.questions, tt.args.echos)
if (err != nil) != tt.wantErr {
t.Errorf("KeyboardInteractive error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KeyboardInteractive = %v, want %v", got, tt.want)
}
log.Printf("finish")
})
}
}

View File

@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"golang.org/x/crypto/ssh/terminal"
"io"
"log"
"net"
"os"
@ -247,7 +249,18 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
auth := make([]gossh.AuthMethod, 0, 2)
if config.SSHBastionInteractive {
auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive()))
var c io.ReadWriteCloser
if terminal.IsTerminal(int(os.Stdin.Fd())) {
c = os.Stdin
} else {
tty, err := os.Open("/dev/tty")
if err != nil {
return nil, err
}
defer tty.Close()
c = tty
}
auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive(c)))
}
if config.SSHBastionPassword != "" {