packer-cn/packer/rpc/muxconn_test.go

255 lines
4.2 KiB
Go
Raw Normal View History

2013-12-08 21:20:27 -05:00
package rpc
import (
"io"
"net"
"sync"
"testing"
)
// readStream performs a single Read of up to 1024 bytes from s and returns
// the bytes read as a string. Any read error fails the test immediately.
//
// NOTE(review): io.Reader allows a single Read to return fewer bytes than the
// peer wrote; callers assume the small payloads used in these tests arrive in
// one read.
func readStream(t *testing.T, s io.Reader) string {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	var data [1024]byte
	n, err := s.Read(data[:])
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	return string(data[:n])
}
func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
2013-12-08 21:20:27 -05:00
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("err: %s", err)
}
// Server side
2013-12-08 21:20:27 -05:00
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
conn, err := l.Accept()
l.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
server = NewMuxConn(conn)
}()
// Client side
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("err: %s", err)
}
client = NewMuxConn(conn)
// Wait for the server
<-doneCh
return
}
func TestMuxConn(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
// When the server is done
doneCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
s0, err := server.Accept(0)
2013-12-08 21:20:27 -05:00
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := server.Dial(1)
2013-12-08 21:20:27 -05:00
if err != nil {
t.Fatalf("err: %s", err)
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
data := readStream(t, s1)
if data != "another" {
t.Fatalf("bad: %#v", data)
}
}()
go func() {
defer wg.Done()
data := readStream(t, s0)
if data != "hello" {
t.Fatalf("bad: %#v", data)
}
}()
wg.Wait()
}()
s0, err := client.Dial(0)
2013-12-08 21:20:27 -05:00
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := client.Accept(1)
2013-12-08 21:20:27 -05:00
if err != nil {
t.Fatalf("err: %s", err)
}
if _, err := s0.Write([]byte("hello")); err != nil {
t.Fatalf("err: %s", err)
}
if _, err := s1.Write([]byte("another")); err != nil {
t.Fatalf("err: %s", err)
}
// Wait for the server to be done
<-doneCh
}
// TestMuxConn_clientCloseRead checks that after the client closes its end of
// stream 0, data the server writes afterwards can still be read by the client.
func TestMuxConn_clientCloseRead(t *testing.T) {
	client, server := testMux(t)
	defer client.Close()
	defer server.Close()

	// Closed once the client has closed its end; the server waits on it
	// before writing.
	waitCh := make(chan struct{})
	go func() {
		conn, err := server.Accept(0)
		if err != nil {
			// t.Fatalf is not allowed outside the test goroutine.
			t.Errorf("err: %s", err)
			return
		}
		defer conn.Close()

		<-waitCh
		if _, err := conn.Write([]byte("foo")); err != nil {
			t.Errorf("err: %s", err)
		}
	}()

	s0, err := client.Dial(0)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if err := s0.Close(); err != nil {
		t.Fatalf("bad: %s", err)
	}

	// Close this to continue on the server side.
	close(waitCh)

	var data [1024]byte
	n, err := s0.Read(data[:])
	// A Read may legitimately return data together with an error (e.g.
	// io.EOF), so pass/fail is decided on the data. The error is included
	// so a failure is diagnosable instead of silently dropped.
	if got := string(data[:n]); got != "foo" {
		t.Fatalf("bad: %#v (err: %v)", got, err)
	}
}
// TestMuxConn_socketClose checks that once the server closes its rwc
// (presumably the underlying connection passed to NewMuxConn), a Read on the
// client's stream 0 returns io.EOF.
func TestMuxConn_socketClose(t *testing.T) {
	client, server := testMux(t)
	defer client.Close()
	defer server.Close()

	go func() {
		if _, err := server.Accept(0); err != nil {
			// t.Fatalf is not allowed outside the test goroutine.
			t.Errorf("err: %s", err)
			return
		}
		server.rwc.Close()
	}()

	s0, err := client.Dial(0)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	var data [1024]byte
	// %v rather than %s: err may be nil here, which %s renders as %!s(<nil>).
	if _, err := s0.Read(data[:]); err != io.EOF {
		t.Fatalf("err: %v", err)
	}
}
// TestMuxConn_clientClosesStreams checks that when one end of stream 0 is
// closed, a Read on the other end returns io.EOF.
//
// NOTE(review): despite the name, it is the server side that closes its
// accepted stream here and the client side that observes io.EOF.
func TestMuxConn_clientClosesStreams(t *testing.T) {
	client, server := testMux(t)
	defer client.Close()
	defer server.Close()

	go func() {
		conn, err := server.Accept(0)
		if err != nil {
			// t.Fatalf is not allowed outside the test goroutine.
			t.Errorf("err: %s", err)
			return
		}
		conn.Close()
	}()

	s0, err := client.Dial(0)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	var data [1024]byte
	// %v rather than %s: err may be nil here, which %s renders as %!s(<nil>).
	if _, err := s0.Read(data[:]); err != io.EOF {
		t.Fatalf("err: %v", err)
	}
}
// TestMuxConn_serverClosesStreams checks that closing the server MuxConn
// ends the client's open stream: a Read on it returns io.EOF.
func TestMuxConn_serverClosesStreams(t *testing.T) {
	client, server := testMux(t)
	defer client.Close()
	defer server.Close()

	// The server only needs to accept the stream; the returned stream and
	// error are intentionally discarded.
	go server.Accept(0)

	s0, err := client.Dial(0)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if err := server.Close(); err != nil {
		t.Fatalf("err: %s", err)
	}

	// Nothing is ever written onto this stream, so the Read can only return
	// because the server closed. It must not block forever and must report
	// io.EOF.
	var data [1024]byte
	// %v rather than %s: err may be nil here, which %s renders as %!s(<nil>).
	if _, err := s0.Read(data[:]); err != io.EOF {
		t.Fatalf("err: %v", err)
	}
}
// TestMuxConnNextId checks that each side of a connected MuxConn pair hands
// out IDs from its own sequence, starting at 1 and incrementing by one per
// call.
func TestMuxConnNextId(t *testing.T) {
	client, server := testMux(t)
	defer client.Close()
	defer server.Close()

	// Calls on the right-hand side are evaluated left to right, so a is the
	// first ID issued and b the second.
	if a, b := client.NextId(), client.NextId(); a != 1 || b != 2 {
		t.Fatalf("IDs should increment: %d %d", a, b)
	}
	if a, b := server.NextId(), server.NextId(); a != 1 || b != 2 {
		t.Fatalf("IDs should increment: %d %d", a, b)
	}
}