write unit test
This commit is contained in:
parent
f50ff1d270
commit
70e3f60539
|
@ -1,14 +1,15 @@
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
)
|
)
|
||||||
|
|
||||||
func KeyboardInteractive() ssh.KeyboardInteractiveChallenge {
|
func KeyboardInteractive(c io.ReadWriter) ssh.KeyboardInteractiveChallenge {
|
||||||
|
t := terminal.NewTerminal(c, "")
|
||||||
return func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
return func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||||
if len(questions) == 0 {
|
if len(questions) == 0 {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
|
@ -21,18 +22,7 @@ func KeyboardInteractive() ssh.KeyboardInteractiveChallenge {
|
||||||
}
|
}
|
||||||
answers := make([]string, len(questions))
|
answers := make([]string, len(questions))
|
||||||
for i := range questions {
|
for i := range questions {
|
||||||
var fd int
|
s, err := t.ReadPassword("")
|
||||||
if terminal.IsTerminal(int(os.Stdin.Fd())) {
|
|
||||||
fd = int(os.Stdin.Fd())
|
|
||||||
} else {
|
|
||||||
tty, err := os.Open("/dev/tty")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer tty.Close()
|
|
||||||
fd = int(tty.Fd())
|
|
||||||
}
|
|
||||||
s, err := terminal.ReadPassword(fd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockTerminal struct {
|
||||||
|
toSend []byte
|
||||||
|
bytesPerRead int
|
||||||
|
received []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MockTerminal) Read(data []byte) (n int, err error) {
|
||||||
|
n = len(data)
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > len(c.toSend) {
|
||||||
|
n = len(c.toSend)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
if c.bytesPerRead > 0 && n > c.bytesPerRead {
|
||||||
|
n = c.bytesPerRead
|
||||||
|
}
|
||||||
|
copy(data, c.toSend[:n])
|
||||||
|
c.toSend = c.toSend[n:]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MockTerminal) Write(data []byte) (n int, err error) {
|
||||||
|
c.received = append(c.received, data...)
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyboardInteractive(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
user string
|
||||||
|
instruction string
|
||||||
|
questions []string
|
||||||
|
echos []bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "questions are none",
|
||||||
|
args: args{
|
||||||
|
questions: []string{},
|
||||||
|
},
|
||||||
|
want: []string{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "input answer interactive",
|
||||||
|
args: args{
|
||||||
|
questions: []string{"this is question"},
|
||||||
|
},
|
||||||
|
want: []string{"xxxx"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &MockTerminal{
|
||||||
|
toSend: []byte("xxxx\r\x1b[A\r"),
|
||||||
|
bytesPerRead: 1,
|
||||||
|
}
|
||||||
|
f := KeyboardInteractive(c)
|
||||||
|
got, err := f(tt.args.user, tt.args.instruction, tt.args.questions, tt.args.echos)
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("KeyboardInteractive error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("KeyboardInteractive = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
log.Printf("finish")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
@ -247,7 +249,18 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) {
|
||||||
auth := make([]gossh.AuthMethod, 0, 2)
|
auth := make([]gossh.AuthMethod, 0, 2)
|
||||||
|
|
||||||
if config.SSHBastionInteractive {
|
if config.SSHBastionInteractive {
|
||||||
auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive()))
|
var c io.ReadWriteCloser
|
||||||
|
if terminal.IsTerminal(int(os.Stdin.Fd())) {
|
||||||
|
c = os.Stdin
|
||||||
|
} else {
|
||||||
|
tty, err := os.Open("/dev/tty")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tty.Close()
|
||||||
|
c = tty
|
||||||
|
}
|
||||||
|
auth = append(auth, gossh.KeyboardInteractive(ssh.KeyboardInteractive(c)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.SSHBastionPassword != "" {
|
if config.SSHBastionPassword != "" {
|
||||||
|
|
Loading…
Reference in New Issue