diff --git a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/SimpleServletServer.java b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/SimpleServletServer.java new file mode 100644 index 00000000000..3e717f1ae71 --- /dev/null +++ b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/SimpleServletServer.java @@ -0,0 +1,66 @@ +package org.eclipse.jetty.websocket.server; + +import java.net.URI; + +import javax.servlet.http.HttpServlet; + +import org.eclipse.jetty.server.SelectChannelConnector; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; + +public class SimpleServletServer +{ + private Server server; + private SelectChannelConnector connector; + private URI serverUri; + private HttpServlet servlet; + + public SimpleServletServer(HttpServlet servlet) + { + this.servlet = servlet; + } + + public URI getServerUri() + { + return serverUri; + } + + public void start() throws Exception + { + // Configure Server + server = new Server(); + connector = new SelectChannelConnector(); + server.addConnector(connector); + + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + server.setHandler(context); + + // Serve capture servlet + context.addServlet(new ServletHolder(servlet),"/*"); + + // Start Server + server.start(); + + String host = connector.getHost(); + if (host == null) + { + host = "localhost"; + } + int port = connector.getLocalPort(); + serverUri = new URI(String.format("ws://%s:%d/",host,port)); + } + + public void stop() + { + try + { + server.stop(); + } + catch (Exception e) + { + e.printStackTrace(System.err); + } + } +} diff --git a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/WebSocketInvalidVersionTest.java b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/WebSocketInvalidVersionTest.java new file mode 100644 index 00000000000..dcd15bb785d --- /dev/null +++ b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/WebSocketInvalidVersionTest.java @@ -0,0 +1,50 @@ +package org.eclipse.jetty.websocket.server; + +import static org.hamcrest.Matchers.*; + +import org.eclipse.jetty.websocket.server.blockhead.BlockheadClient; +import org.eclipse.jetty.websocket.server.examples.MyEchoServlet; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class WebSocketInvalidVersionTest +{ + private static SimpleServletServer server; + + @BeforeClass + public static void startServer() throws Exception + { + server = new SimpleServletServer(new MyEchoServlet()); + server.start(); + } + + @AfterClass + public static void stopServer() + { + server.stop(); + } + + /** + * Test the requirement of responding with an http 400 when using a Sec-WebSocket-Version that is unsupported. + */ + @Test + public void testRequestVersion29() throws Exception + { + BlockheadClient client = new BlockheadClient(server.getServerUri()); + client.setVersion(29); // intentionally bad version + try + { + client.connect(); + client.sendStandardRequest(); + String respHeader = client.readResponseHeader(); + Assert.assertThat("Response Code",respHeader,startsWith("HTTP/1.1 400 Unsupported websocket version specification")); + Assert.assertThat("Response Header Versions",respHeader,containsString("Sec-WebSocket-Version: 13, 0\r\n")); + } + finally + { + client.close(); + } + } +} diff --git a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/blockhead/BlockheadClient.java b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/blockhead/BlockheadClient.java index a3479c952b6..009ecd10f1c 100644 --- a/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/blockhead/BlockheadClient.java +++ b/jetty-websocket/websocket-server/src/test/java/org/eclipse/jetty/websocket/server/blockhead/BlockheadClient.java @@ -1,18 +1,31 @@ package org.eclipse.jetty.websocket.server.blockhead; import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import java.io.BufferedReader; +import java.io.EOFException; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; import java.net.HttpURLConnection; +import java.net.InetAddress; +import java.net.Socket; import java.net.URI; import java.net.URISyntaxException; import java.nio.ByteBuffer; +import java.util.Queue; import java.util.concurrent.LinkedBlockingDeque; import javax.net.ssl.HttpsURLConnection; import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.io.StandardByteBufferPool; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.IO; +import org.eclipse.jetty.util.StringUtil; +import org.eclipse.jetty.util.Utf8StringBuilder; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.api.WebSocketException; @@ -41,6 +54,13 @@ public class BlockheadClient implements Parser.Listener private final Parser parser; private final LinkedBlockingDeque incomingFrameQueue; + private Socket socket; + private OutputStream out; + private InputStream in; + private int version = 13; // default to RFC-6455 + private String protocols; + private String extensions; + public BlockheadClient(URI destWebsocketURI) throws URISyntaxException { Assert.assertThat("Websocket URI scheme",destWebsocketURI.getScheme(),anyOf(is("ws"),is("wss"))); @@ -61,9 +81,34 @@ public class BlockheadClient implements Parser.Listener incomingFrameQueue = new LinkedBlockingDeque<>(); } + public void close() + { + IO.close(in); + IO.close(out); + try + { + socket.close(); + } + catch (IOException ignore) + { + /* ignore */ + } + } + public void connect() throws IOException { + InetAddress destAddr = InetAddress.getByName(destHttpURI.getHost()); + int port = destHttpURI.getPort(); + socket = new Socket(destAddr,port); + out = socket.getOutputStream(); + socket.setSoTimeout(1000); + in = socket.getInputStream(); + } + + public String getExtensions() + { + return extensions; } public URI getHttpURI() @@ -71,11 +116,50 @@ public class BlockheadClient implements Parser.Listener return destHttpURI; } + public String getProtocols() + { + return protocols; + } + + public int getVersion() + { + return version; + } + public URI getWebsocketURI() { return destWebsocketURI; } + public void lookFor(String string) throws IOException + { + String orig = string; + Utf8StringBuilder scanned = new Utf8StringBuilder(); + try + { + while (true) + { + int b = in.read(); + if (b < 0) + { + throw new EOFException(); + } + scanned.append((byte)b); + assertEquals("looking for\"" + orig + "\" in '" + scanned + "'",string.charAt(0),b); + if (string.length() == 1) + { + break; + } + string = string.substring(1); + } + } + catch (IOException e) + { + System.err.println("IOE while looking for \"" + orig + "\" in '" + scanned + "'"); + throw e; + } + } + @Override public void onFrame(BaseFrame frame) { @@ -91,17 +175,141 @@ public class BlockheadClient implements Parser.Listener LOG.warn(e); } - public void write(BaseFrame frame) + private void read(ByteBuffer buf) throws IOException + { + while ((in.available() > 0) && (buf.remaining() > 0)) + { + buf.put((byte)in.read()); + } + } + + public Queue readFrames(int expectedCount) throws IOException + { + int startCount = incomingFrameQueue.size(); + + ByteBuffer buf = bufferPool.acquire(policy.getBufferSize(),false); + try + { + while (incomingFrameQueue.size() < (startCount + expectedCount)) + { + read(buf); + parser.parse(buf); + } + } + finally + { + bufferPool.release(buf); + } + + return incomingFrameQueue; + } + + public String readResponseHeader() throws IOException + { + InputStreamReader isr = new InputStreamReader(in); + BufferedReader reader = new BufferedReader(isr); + StringBuilder header = new StringBuilder(); + // Read the response header + String line = reader.readLine(); + Assert.assertNotNull(line); + Assert.assertThat(line,startsWith("HTTP/1.1 ")); + header.append(line).append("\r\n"); + while ((line = reader.readLine()) != null) + { + if (line.trim().length() == 0) + { + break; + } + header.append(line).append("\r\n"); + } + return header.toString(); + } + + public void sendStandardRequest() throws IOException + { + StringBuilder req = new StringBuilder(); + req.append("GET /chat HTTP/1.1\r\n"); + req.append("Host: ").append(destWebsocketURI.getHost()); + if (destWebsocketURI.getPort() > 0) + { + req.append(':').append(destWebsocketURI.getPort()); + } + req.append("\r\n"); + req.append("Upgrade: websocket\r\n"); + req.append("Connection: Upgrade\r\n"); + req.append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"); + req.append("Sec-WebSocket-Origin: ").append(destWebsocketURI.toASCIIString()).append("\r\n"); + if (StringUtil.isNotBlank(protocols)) + { + req.append("Sec-WebSocket-Protocol: ").append(protocols).append("\r\n"); + } + if (StringUtil.isNotBlank(extensions)) + { + req.append("Sec-WebSocket-Extensions: ").append(extensions).append("\r\n"); + } + req.append("Sec-WebSocket-Version: ").append(version).append("\r\n"); + req.append("\r\n"); + write(req.toString()); + } + + public void setExtensions(String extensions) + { + this.extensions = extensions; + } + + public void setProtocols(String protocols) + { + this.protocols = protocols; + } + + public void setVersion(int version) + { + this.version = version; + } + + public void skipTo(String string) throws IOException + { + int state = 0; + + while (true) + { + int b = in.read(); + if (b < 0) + { + throw new EOFException(); + } + + if (b == string.charAt(state)) + { + state++; + if (state == string.length()) + { + break; + } + } + else + { + state = 0; + } + } + } + + public void write(BaseFrame frame) throws IOException { ByteBuffer buf = bufferPool.acquire(policy.getBufferSize(),false); try { generator.generate(buf,frame); - // TODO write to Socket + out.write(BufferUtil.toArray(buf)); } finally { bufferPool.release(buf); } } + + public void write(String str) throws IOException + { + out.write(StringUtil.getBytes(str,StringUtil.__ISO_8859_1)); + } }