diff --git a/jetty-websocket/websocket-client/src/main/java/org/eclipse/jetty/websocket/client/io/UpgradeConnection.java b/jetty-websocket/websocket-client/src/main/java/org/eclipse/jetty/websocket/client/io/UpgradeConnection.java index cb362fe547f..3bba82fc79e 100644 --- a/jetty-websocket/websocket-client/src/main/java/org/eclipse/jetty/websocket/client/io/UpgradeConnection.java +++ b/jetty-websocket/websocket-client/src/main/java/org/eclipse/jetty/websocket/client/io/UpgradeConnection.java @@ -74,6 +74,9 @@ public class UpgradeConnection extends AbstractConnection } } + /** HTTP Response Code: 101 Switching Protocols */ + private static final int SWITCHING_PROTOCOLS = 101; + private static final Logger LOG = Log.getLogger(UpgradeConnection.class); private final ByteBufferPool bufferPool; private final ConnectPromise connectPromise; @@ -239,6 +242,19 @@ public class UpgradeConnection extends AbstractConnection private void validateResponse(ClientUpgradeResponse response) { + // Validate Response Status Code + if (response.getStatusCode() != SWITCHING_PROTOCOLS) + { + throw new UpgradeException("Didn't switch protocols"); + } + + // Validate Connection header + String connection = response.getHeader("Connection"); + if (!"upgrade".equalsIgnoreCase(connection)) + { + throw new UpgradeException("Connection is " + connection + " (expected upgrade)"); + } + // Check the Accept hash String reqKey = request.getKey(); String expectedHash = AcceptHash.hashKey(reqKey); diff --git a/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/ClientConnectTest.java b/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/ClientConnectTest.java index 15e44035995..d17adfbb07d 100644 --- a/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/ClientConnectTest.java +++ b/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/ClientConnectTest.java @@ -20,6 +20,7 @@ package org.eclipse.jetty.websocket.client; import java.net.ConnectException; import java.net.URI; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; @@ -31,6 +32,7 @@ import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeException; import org.eclipse.jetty.websocket.client.blockhead.BlockheadServer; import org.eclipse.jetty.websocket.client.blockhead.BlockheadServer.ServerConnection; +import org.eclipse.jetty.websocket.common.AcceptHash; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -102,6 +104,131 @@ public class ClientConnectTest } } + @Test(expected = UpgradeException.class) + public void testBadHandshake_GetOK() throws Exception + { + TrackingSocket wsocket = new TrackingSocket(); + + URI wsUri = server.getWsUri(); + Future future = client.connect(wsocket,wsUri); + + ServerConnection connection = server.accept(); + connection.readRequest(); + // Send OK to GET but not upgrade + connection.respond("HTTP/1.1 200 OK\r\n\r\n"); + + // The attempt to get upgrade response future should throw error + try + { + future.get(500,TimeUnit.MILLISECONDS); + Assert.fail("Expected ExecutionException -> UpgradeException"); + } + catch (ExecutionException e) + { + // Expected Path - throw underlying exception + FutureCallback.rethrow(e); + } + } + + @Test(expected = UpgradeException.class) + public void testBadHandshake_GetOK_WithSecWebSocketAccept() throws Exception + { + TrackingSocket wsocket = new TrackingSocket(); + + URI wsUri = server.getWsUri(); + Future future = client.connect(wsocket,wsUri); + + ServerConnection connection = server.accept(); + List requestLines = connection.readRequestLines(); + String key = connection.parseWebSocketKey(requestLines); + + // Send OK to GET but not upgrade + StringBuilder resp = new StringBuilder(); + resp.append("HTTP/1.1 200 OK\r\n"); // intentionally 200 (not 101) + // Include a value accept key + resp.append("Sec-WebSocket-Accept: ").append(AcceptHash.hashKey(key)).append("\r\n"); + resp.append("\r\n"); + connection.respond(resp.toString()); + + // The attempt to get upgrade response future should throw error + try + { + future.get(500,TimeUnit.MILLISECONDS); + Assert.fail("Expected ExecutionException -> UpgradeException"); + } + catch (ExecutionException e) + { + // Expected Path - throw underlying exception + FutureCallback.rethrow(e); + } + } + + @Test(expected = UpgradeException.class) + public void testBadHandshake_SwitchingProtocols_InvalidConnectionHeader() throws Exception + { + TrackingSocket wsocket = new TrackingSocket(); + + URI wsUri = server.getWsUri(); + Future future = client.connect(wsocket,wsUri); + + ServerConnection connection = server.accept(); + List requestLines = connection.readRequestLines(); + String key = connection.parseWebSocketKey(requestLines); + + // Send Switching Protocols 101, but invalid 'Connection' header + StringBuilder resp = new StringBuilder(); + resp.append("HTTP/1.1 101 Switching Protocols\r\n"); + resp.append("Sec-WebSocket-Accept: ").append(AcceptHash.hashKey(key)).append("\r\n"); + resp.append("Connection: close\r\n"); + resp.append("\r\n"); + connection.respond(resp.toString()); + + // The attempt to get upgrade response future should throw error + try + { + future.get(500,TimeUnit.MILLISECONDS); + Assert.fail("Expected ExecutionException -> UpgradeException"); + } + catch (ExecutionException e) + { + // Expected Path - throw underlying exception + FutureCallback.rethrow(e); + } + } + + @Test(expected = UpgradeException.class) + public void testBadHandshake_SwitchingProtocols_NoConnectionHeader() throws Exception + { + TrackingSocket wsocket = new TrackingSocket(); + + URI wsUri = server.getWsUri(); + Future future = client.connect(wsocket,wsUri); + + ServerConnection connection = server.accept(); + List requestLines = connection.readRequestLines(); + String key = connection.parseWebSocketKey(requestLines); + + // Send Switching Protocols 101, but no 'Connection' header + StringBuilder resp = new StringBuilder(); + resp.append("HTTP/1.1 101 Switching Protocols\r\n"); + resp.append("Sec-WebSocket-Accept: ").append(AcceptHash.hashKey(key)).append("\r\n"); + // Intentionally leave out Connection header + resp.append("\r\n"); + connection.respond(resp.toString()); + + // The attempt to get upgrade response future should throw error + try + { + future.get(500,TimeUnit.MILLISECONDS); + Assert.fail("Expected ExecutionException -> UpgradeException"); + } + catch (ExecutionException e) + { + // Expected Path - throw underlying exception + FutureCallback.rethrow(e); + } + } + @Test(expected = UpgradeException.class) public void testBadUpgrade() throws Exception { diff --git a/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/blockhead/BlockheadServer.java b/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/blockhead/BlockheadServer.java index f1c9ef9effc..2b20ee85bf9 100644 --- a/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/blockhead/BlockheadServer.java +++ b/jetty-websocket/websocket-client/src/test/java/org/eclipse/jetty/websocket/client/blockhead/BlockheadServer.java @@ -250,6 +250,47 @@ public class BlockheadServer } } + public List parseExtensions(List requestLines) + { + List extensionConfigs = new ArrayList<>(); + + Pattern patExts = Pattern.compile("^Sec-WebSocket-Extensions: (.*)$",Pattern.CASE_INSENSITIVE); + + Matcher mat; + for (String line : requestLines) + { + mat = patExts.matcher(line); + if (mat.matches()) + { + // found extensions + String econf = mat.group(1); + ExtensionConfig config = ExtensionConfig.parse(econf); + extensionConfigs.add(config); + } + } + + return extensionConfigs; + } + + public String parseWebSocketKey(List requestLines) + { + String key = null; + + Pattern patKey = Pattern.compile("^Sec-WebSocket-Key: (.*)$",Pattern.CASE_INSENSITIVE); + + Matcher mat; + for (String line : requestLines) + { + mat = patKey.matcher(line); + if (mat.matches()) + { + key = mat.group(1); + } + } + + return key; + } + public int read(ByteBuffer buf) throws IOException { int len = 0; @@ -329,6 +370,24 @@ public class BlockheadServer return request.toString(); } + public List readRequestLines() throws IOException + { + LOG.debug("Reading client request header"); + List lines = new ArrayList<>(); + + BufferedReader in = new BufferedReader(new InputStreamReader(getInputStream())); + for (String line = in.readLine(); line != null; line = in.readLine()) + { + if (line.length() == 0) + { + break; + } + lines.add(line); + } + + return lines; + } + public void respond(String rawstr) throws IOException { LOG.debug("respond(){}{}","\n",rawstr); @@ -343,39 +402,14 @@ public class BlockheadServer public void upgrade() throws IOException { - List extensionConfigs = new ArrayList<>(); + List requestLines = readRequestLines(); + List extensionConfigs = parseExtensions(requestLines); + String key = parseWebSocketKey(requestLines); - Pattern patExts = Pattern.compile("^Sec-WebSocket-Extensions: (.*)$",Pattern.CASE_INSENSITIVE); - Pattern patKey = Pattern.compile("^Sec-WebSocket-Key: (.*)$",Pattern.CASE_INSENSITIVE); + LOG.debug("Client Request Extensions: {}",extensionConfigs); + LOG.debug("Client Request Key: {}",key); - Matcher mat; - String key = "not sent"; - BufferedReader in = new BufferedReader(new InputStreamReader(getInputStream())); - for (String line = in.readLine(); line != null; line = in.readLine()) - { - if (line.length() == 0) - { - break; - } - - // Check for extensions - mat = patExts.matcher(line); - if (mat.matches()) - { - // found extensions - String econf = mat.group(1); - ExtensionConfig config = ExtensionConfig.parse(econf); - extensionConfigs.add(config); - continue; - } - - // Check for Key - mat = patKey.matcher(line); - if (mat.matches()) - { - key = mat.group(1); - } - } + Assert.assertThat("Request: Sec-WebSocket-Key",key,notNullValue()); // collect extensions configured in response header ExtensionStack extensionStack = new ExtensionStack(extensionRegistry); @@ -405,6 +439,7 @@ public class BlockheadServer // Setup Response StringBuilder resp = new StringBuilder(); resp.append("HTTP/1.1 101 Upgrade\r\n"); + resp.append("Connection: upgrade\r\n"); resp.append("Sec-WebSocket-Accept: "); resp.append(AcceptHash.hashKey(key)).append("\r\n"); if (!extensionStack.hasNegotiatedExtensions()) @@ -436,6 +471,7 @@ public class BlockheadServer resp.append("\r\n"); // Write Response + LOG.debug("Response: {}",resp.toString()); write(resp.toString().getBytes()); }