write unit test
This commit is contained in:
parent
f50ff1d270
commit
70e3f60539
|
@ -1,14 +1,15 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
func KeyboardInteractive() ssh.KeyboardInteractiveChallenge {
|
||||
func KeyboardInteractive(c io.ReadWriter) ssh.KeyboardInteractiveChallenge {
|
||||
t := terminal.NewTerminal(c, "")
|
||||
return func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||
if len(questions) == 0 {
|
||||
return []string{}, nil
|
||||
|
@ -21,18 +22,7 @@ func KeyboardInteractive() ssh.KeyboardInteractiveChallenge {
|
|||
}
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
var fd int
|
||||
if terminal.IsTerminal(int(os.Stdin.Fd())) {
|
||||
fd = int(os.Stdin.Fd())
|
||||
} else {
|
||||
tty, err := os.Open("/dev/tty")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tty.Close()
|
||||
fd = int(tty.Fd())
|
||||
}
|
||||
s, err := terminal.ReadPassword(fd)
|
||||
s, err := t.ReadPassword("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MockTerminal struct {
|
||||
toSend []byte
|
||||
bytesPerRead int
|
||||
received []byte
|
||||
}
|
||||
|
||||
func (c *MockTerminal) Read(data []byte) (n int, err error) {
|
||||
n = len(data)
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
if n > len(c.toSend) {
|
||||
n = len(c.toSend)
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if c.bytesPerRead > 0 && n > c.bytesPerRead {
|
||||
n = c.bytesPerRead
|
||||
}
|
||||
copy(data, c.toSend[:n])
|
||||
c.toSend = c.toSend[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *MockTerminal) Write(data []byte) (n int, err error) {
|
||||
c.received = append(c.received, data...)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func TestKeyboardInteractive(t *testing.T) {
|
||||
type args struct {
|
||||
user string
|
||||
instruction string
|
||||
questions []string
|
||||
echos []bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "questions are none",
|
||||
args: args{
|
||||
questions: []string{},
|
||||
},
|
||||
want: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "input answer interactive",
|
||||
args: args{
|
||||
questions: []string{"this is question"},
|
||||
},
|
||||
want: []string{"xxxx"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &MockTerminal{
|
||||
toSend: []byte("xxxx\r\x1b[A\r"),
|
||||
bytesPerRead: 1,
|
||||
}
|
||||
f := KeyboardInteractive(c)
|
||||
got, err := f(tt.args.user, tt.args.instruction, tt.args.questions, tt.args.echos)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KeyboardInteractive error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KeyboardInteractive = %v, want %v", got, tt.want)
|
||||
}
|
||||
log.Printf("finish")
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
@ -247,7 +249,18 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
|||
auth := make([]gossh.AuthMethod, 0, 2)
|
||||
|
||||
if config.SSHBastionInteractive {
|
||||
auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive()))
|
||||
var c io.ReadWriteCloser
|
||||
if terminal.IsTerminal(int(os.Stdin.Fd())) {
|
||||
c = os.Stdin
|
||||
} else {
|
||||
tty, err := os.Open("/dev/tty")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tty.Close()
|
||||
c = tty
|
||||
}
|
||||
auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive(c)))
|
||||
}
|
||||
|
||||
if config.SSHBastionPassword != "" {
|
||||
|
|
Loading…
Reference in New Issue