Initial pass at BlockheadClient and SimpleServletServer to aid in unit testing

This commit is contained in:
Joakim Erdfelt 2012-06-29 10:08:10 -07:00
parent 9de009c1b2
commit f28324e31a
3 changed files with 326 additions and 2 deletions

View File

@ -0,0 +1,66 @@
package org.eclipse.jetty.websocket.server;
import java.net.URI;
import javax.servlet.http.HttpServlet;
import org.eclipse.jetty.server.SelectChannelConnector;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
public class SimpleServletServer
{
private Server server;
private SelectChannelConnector connector;
private URI serverUri;
private HttpServlet servlet;
public SimpleServletServer(HttpServlet servlet)
{
this.servlet = servlet;
}
public URI getServerUri()
{
return serverUri;
}
public void start() throws Exception
{
// Configure Server
server = new Server();
connector = new SelectChannelConnector();
server.addConnector(connector);
ServletContextHandler context = new ServletContextHandler();
context.setContextPath("/");
server.setHandler(context);
// Serve capture servlet
context.addServlet(new ServletHolder(servlet),"/*");
// Start Server
server.start();
String host = connector.getHost();
if (host == null)
{
host = "localhost";
}
int port = connector.getLocalPort();
serverUri = new URI(String.format("ws://%s:%d/",host,port));
}
public void stop()
{
try
{
server.stop();
}
catch (Exception e)
{
e.printStackTrace(System.err);
}
}
}

View File

@ -0,0 +1,50 @@
package org.eclipse.jetty.websocket.server;
import static org.hamcrest.Matchers.*;
import org.eclipse.jetty.websocket.server.blockhead.BlockheadClient;
import org.eclipse.jetty.websocket.server.examples.MyEchoServlet;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
public class WebSocketInvalidVersionTest
{
private static SimpleServletServer server;
@BeforeClass
public static void startServer() throws Exception
{
server = new SimpleServletServer(new MyEchoServlet());
server.start();
}
@AfterClass
public static void stopServer()
{
server.stop();
}
/**
* Test the requirement of responding with an http 400 when using a Sec-WebSocket-Version that is unsupported.
*/
@Test
public void testRequestVersion29() throws Exception
{
BlockheadClient client = new BlockheadClient(server.getServerUri());
client.setVersion(29); // intentionally bad version
try
{
client.connect();
client.sendStandardRequest();
String respHeader = client.readResponseHeader();
Assert.assertThat("Response Code",respHeader,startsWith("HTTP/1.1 400 Unsupported websocket version specification"));
Assert.assertThat("Response Header Versions",respHeader,containsString("Sec-WebSocket-Version: 13, 0\r\n"));
}
finally
{
client.close();
}
}
}

View File

@ -1,18 +1,31 @@
package org.eclipse.jetty.websocket.server.blockhead;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.io.BufferedReader;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.InetAddress;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingDeque;
import javax.net.ssl.HttpsURLConnection;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.StandardByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.WebSocketException;
@ -41,6 +54,13 @@ public class BlockheadClient implements Parser.Listener
private final Parser parser;
private final LinkedBlockingDeque<BaseFrame> incomingFrameQueue;
private Socket socket;
private OutputStream out;
private InputStream in;
private int version = 13; // default to RFC-6455
private String protocols;
private String extensions;
public BlockheadClient(URI destWebsocketURI) throws URISyntaxException
{
Assert.assertThat("Websocket URI scheme",destWebsocketURI.getScheme(),anyOf(is("ws"),is("wss")));
@ -61,9 +81,34 @@ public class BlockheadClient implements Parser.Listener
incomingFrameQueue = new LinkedBlockingDeque<>();
}
public void close()
{
IO.close(in);
IO.close(out);
try
{
socket.close();
}
catch (IOException ignore)
{
/* ignore */
}
}
public void connect() throws IOException
{
InetAddress destAddr = InetAddress.getByName(destHttpURI.getHost());
int port = destHttpURI.getPort();
socket = new Socket(destAddr,port);
out = socket.getOutputStream();
socket.setSoTimeout(1000);
in = socket.getInputStream();
}
public String getExtensions()
{
return extensions;
}
public URI getHttpURI()
@ -71,11 +116,50 @@ public class BlockheadClient implements Parser.Listener
return destHttpURI;
}
public String getProtocols()
{
return protocols;
}
public int getVersion()
{
return version;
}
public URI getWebsocketURI()
{
return destWebsocketURI;
}
public void lookFor(String string) throws IOException
{
String orig = string;
Utf8StringBuilder scanned = new Utf8StringBuilder();
try
{
while (true)
{
int b = in.read();
if (b < 0)
{
throw new EOFException();
}
scanned.append((byte)b);
assertEquals("looking for\"" + orig + "\" in '" + scanned + "'",string.charAt(0),b);
if (string.length() == 1)
{
break;
}
string = string.substring(1);
}
}
catch (IOException e)
{
System.err.println("IOE while looking for \"" + orig + "\" in '" + scanned + "'");
throw e;
}
}
@Override
public void onFrame(BaseFrame frame)
{
@ -91,17 +175,141 @@ public class BlockheadClient implements Parser.Listener
LOG.warn(e);
}
public void write(BaseFrame frame)
private void read(ByteBuffer buf) throws IOException
{
while ((in.available() > 0) && (buf.remaining() > 0))
{
buf.put((byte)in.read());
}
}
public Queue<BaseFrame> readFrames(int expectedCount) throws IOException
{
int startCount = incomingFrameQueue.size();
ByteBuffer buf = bufferPool.acquire(policy.getBufferSize(),false);
try
{
while (incomingFrameQueue.size() < (startCount + expectedCount))
{
read(buf);
parser.parse(buf);
}
}
finally
{
bufferPool.release(buf);
}
return incomingFrameQueue;
}
public String readResponseHeader() throws IOException
{
InputStreamReader isr = new InputStreamReader(in);
BufferedReader reader = new BufferedReader(isr);
StringBuilder header = new StringBuilder();
// Read the response header
String line = reader.readLine();
Assert.assertNotNull(line);
Assert.assertThat(line,startsWith("HTTP/1.1 "));
header.append(line).append("\r\n");
while ((line = reader.readLine()) != null)
{
if (line.trim().length() == 0)
{
break;
}
header.append(line).append("\r\n");
}
return header.toString();
}
public void sendStandardRequest() throws IOException
{
StringBuilder req = new StringBuilder();
req.append("GET /chat HTTP/1.1\r\n");
req.append("Host: ").append(destWebsocketURI.getHost());
if (destWebsocketURI.getPort() > 0)
{
req.append(':').append(destWebsocketURI.getPort());
}
req.append("\r\n");
req.append("Upgrade: websocket\r\n");
req.append("Connection: Upgrade\r\n");
req.append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
req.append("Sec-WebSocket-Origin: ").append(destWebsocketURI.toASCIIString()).append("\r\n");
if (StringUtil.isNotBlank(protocols))
{
req.append("Sec-WebSocket-Protocol: ").append(protocols).append("\r\n");
}
if (StringUtil.isNotBlank(extensions))
{
req.append("Sec-WebSocket-Extensions: ").append(extensions).append("\r\n");
}
req.append("Sec-WebSocket-Version: ").append(version).append("\r\n");
req.append("\r\n");
write(req.toString());
}
public void setExtensions(String extensions)
{
this.extensions = extensions;
}
public void setProtocols(String protocols)
{
this.protocols = protocols;
}
public void setVersion(int version)
{
this.version = version;
}
public void skipTo(String string) throws IOException
{
int state = 0;
while (true)
{
int b = in.read();
if (b < 0)
{
throw new EOFException();
}
if (b == string.charAt(state))
{
state++;
if (state == string.length())
{
break;
}
}
else
{
state = 0;
}
}
}
public void write(BaseFrame frame) throws IOException
{
ByteBuffer buf = bufferPool.acquire(policy.getBufferSize(),false);
try
{
generator.generate(buf,frame);
// TODO write to Socket
out.write(BufferUtil.toArray(buf));
}
finally
{
bufferPool.release(buf);
}
}
public void write(String str) throws IOException
{
out.write(StringUtil.getBytes(str,StringUtil.__ISO_8859_1));
}
}