Merge pull request #3470 from lachlan-roberts/jetty-10.0.x-3462-websocketclient-validation

Issue #3462 - websocket upgrade request valdiation
This commit is contained in:
Greg Wilkins 2019-03-19 16:11:09 +11:00 committed by GitHub
commit f037258725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 106 additions and 99 deletions

View File

@ -631,7 +631,7 @@ public class ConfiguratorTest
FrameHandlerTracker clientSocket = new FrameHandlerTracker();
ClientUpgradeRequest upgradeRequest = ClientUpgradeRequest.from(client, wsUri, clientSocket);
upgradeRequest.header("sec-websocket-protocol", "echo, chat, status");
upgradeRequest.setSubProtocols("echo","chat","status");
Future<FrameHandler.CoreSession> clientConnectFuture = client.connect(upgradeRequest);
assertProtocols(clientSocket, clientConnectFuture, is("Requested Protocols: [echo,chat,status]"));
@ -650,8 +650,7 @@ public class ConfiguratorTest
FrameHandlerTracker clientSocket = new FrameHandlerTracker();
ClientUpgradeRequest upgradeRequest = ClientUpgradeRequest.from(client, wsUri, clientSocket);
// header name is not to spec (case wise)
upgradeRequest.header("Sec-Websocket-Protocol", "echo, chat, status");
upgradeRequest.setSubProtocols("echo","chat","status");
Future<FrameHandler.CoreSession> clientConnectFuture = client.connect(upgradeRequest);
assertProtocols(clientSocket, clientConnectFuture, is("Requested Protocols: [echo,chat,status]"));

View File

@ -20,13 +20,13 @@ package org.eclipse.jetty.websocket.core.client;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.HttpConversation;
@ -83,14 +83,6 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
protected final CompletableFuture<FrameHandler.CoreSession> futureCoreSession;
private final WebSocketCoreClient wsClient;
private List<UpgradeListener> upgradeListeners = new ArrayList<>();
/**
* Offered Extensions
*/
private List<ExtensionConfig> extensions = new ArrayList<>();
/**
* Offered SubProtocols
*/
private List<String> subProtocols = new ArrayList<>();
public ClientUpgradeRequest(WebSocketCoreClient webSocketClient, URI requestURI)
{
@ -133,47 +125,56 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
public void addExtensions(ExtensionConfig... configs)
{
HttpFields headers = getHeaders();
for (ExtensionConfig config : configs)
{
this.extensions.add(config);
}
updateWebSocketExtensionHeader();
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
public void addExtensions(String... configs)
{
this.extensions.addAll(ExtensionConfig.parseList(configs));
updateWebSocketExtensionHeader();
HttpFields headers = getHeaders();
for (String config : configs)
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, ExtensionConfig.parse(config).getParameterizedName());
}
public List<ExtensionConfig> getExtensions()
{
List<ExtensionConfig> extensions = getHeaders().getCSV(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, true)
.stream()
.map(ExtensionConfig::parse)
.collect(Collectors.toList());
return extensions;
}
public void setExtensions(List<ExtensionConfig> configs)
{
this.extensions = configs;
updateWebSocketExtensionHeader();
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
for (ExtensionConfig config : configs)
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
public List<String> getSubProtocols()
{
return this.subProtocols;
List<String> subProtocols = getHeaders().getCSV(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, true);
return subProtocols;
}
public void setSubProtocols(String... protocols)
{
this.subProtocols.clear();
this.subProtocols.addAll(Arrays.asList(protocols));
updateWebSocketSubProtocolHeader();
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : protocols)
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
public void setSubProtocols(List<String> protocols)
{
this.subProtocols.clear();
this.subProtocols.addAll(protocols);
updateWebSocketSubProtocolHeader();
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : protocols)
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
@Override
@ -249,25 +250,16 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
public void upgrade(HttpResponse response, HttpConnectionOverHTTP httpConnection)
{
if (!this.getHeaders().get(HttpHeader.UPGRADE).equalsIgnoreCase("websocket"))
{
// Not my upgrade
throw new HttpResponseException("Not a WebSocket Upgrade", response);
}
HttpClient httpClient = wsClient.getHttpClient();
// Check the Accept hash
String reqKey = this.getHeaders().get(HttpHeader.SEC_WEBSOCKET_KEY);
String expectedHash = WebSocketCore.hashKey(reqKey);
String respHash = response.getHeaders().get(HttpHeader.SEC_WEBSOCKET_ACCEPT);
if (expectedHash.equalsIgnoreCase(respHash) == false)
{
throw new HttpResponseException("Invalid Sec-WebSocket-Accept hash (was:" + respHash + ", expected:" + expectedHash + ")", response);
}
// Verify the Negotiated Extensions
ExtensionStack extensionStack = new ExtensionStack(wsClient.getExtensionRegistry());
// Parse the Negotiated Extensions
List<ExtensionConfig> extensions = new ArrayList<>();
HttpField extField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
if (extField != null)
@ -286,9 +278,23 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
}
}
// Verify the Negotiated Extensions
List<ExtensionConfig> offeredExtensions = getExtensions();
for (ExtensionConfig config : extensions)
{
long numMatch = offeredExtensions.stream().filter(c -> config.getName().equalsIgnoreCase(c.getName())).count();
if (numMatch < 1)
throw new WebSocketException("Upgrade failed: Sec-WebSocket-Extensions contained extension not requested");
if (numMatch > 1)
throw new WebSocketException("Upgrade failed: Sec-WebSocket-Extensions contained more than one extension of the same name");
}
// Negotiate the extension stack
HttpClient httpClient = wsClient.getHttpClient();
ExtensionStack extensionStack = new ExtensionStack(wsClient.getExtensionRegistry());
extensionStack.negotiate(wsClient.getObjectFactory(), httpClient.getByteBufferPool(), extensions);
// Check the negotiated subprotocol
// Get the negotiated subprotocol
String negotiatedSubProtocol = null;
HttpField subProtocolField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
if (subProtocolField != null)
@ -297,20 +303,18 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
if (values != null)
{
if (values.length > 1)
{
throw new WebSocketException("Too many WebSocket subprotocol's in response: " + values);
}
throw new WebSocketException("Upgrade failed: Too many WebSocket subprotocol's in response: " + values);
else if (values.length == 1)
{
negotiatedSubProtocol = values[0];
}
}
}
if (!subProtocols.isEmpty() && !subProtocols.contains(negotiatedSubProtocol))
{
throw new WebSocketException("Upgrade failed: subprotocol [" + negotiatedSubProtocol + "] not found in offered subprotocols " + subProtocols);
}
// Verify the negotiated subprotocol
List<String> offeredSubProtocols = getSubProtocols();
if (negotiatedSubProtocol == null && !offeredSubProtocols.isEmpty())
throw new WebSocketException("Upgrade failed: no subprotocol selected from offered subprotocols ");
if (negotiatedSubProtocol != null && !offeredSubProtocols.contains(negotiatedSubProtocol))
throw new WebSocketException("Upgrade failed: subprotocol [" + negotiatedSubProtocol + "] not found in offered subprotocols " + offeredSubProtocols);
// We can upgrade
EndPoint endp = httpConnection.getEndPoint();
@ -435,24 +439,4 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
}
}
}
private void updateWebSocketExtensionHeader()
{
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
for (ExtensionConfig config : extensions)
{
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
}
private void updateWebSocketSubProtocolHeader()
{
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : subProtocols)
{
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
}
}

View File

@ -18,6 +18,16 @@
package org.eclipse.jetty.websocket.core.server;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.QuotedCSV;
@ -28,15 +38,6 @@ import org.eclipse.jetty.websocket.core.ExtensionConfig;
import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry;
import org.eclipse.jetty.websocket.core.internal.ExtensionStack;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.stream.Collectors;
public class Negotiation
{
private final Request baseRequest;

View File

@ -43,6 +43,7 @@ import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.Behavior;
import org.eclipse.jetty.websocket.core.ExtensionConfig;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.WebSocketConstants;
import org.eclipse.jetty.websocket.core.internal.Negotiated;
@ -60,8 +61,7 @@ public final class RFC6455Handshaker implements Handshaker
private static final HttpField CONNECTION_UPGRADE = new PreEncodedHttpField(HttpHeader.CONNECTION, HttpHeader.UPGRADE.asString());
private static final HttpField SERVER_VERSION = new PreEncodedHttpField(HttpHeader.SERVER, HttpConfiguration.SERVER_VERSION);
public boolean upgradeRequest(WebSocketNegotiator negotiator, HttpServletRequest request,
HttpServletResponse response,
public boolean upgradeRequest(WebSocketNegotiator negotiator, HttpServletRequest request, HttpServletResponse response,
FrameHandler.Customizer defaultCustomizer) throws IOException
{
Request baseRequest = Request.getBaseRequest(request);
@ -153,22 +153,42 @@ public final class RFC6455Handshaker implements Handshaker
return false;
}
// Check if subprotocol negotiated
// validate negotiated subprotocol
String subprotocol = negotiation.getSubprotocol();
if (negotiation.getOfferedSubprotocols().size() > 0)
if (subprotocol != null)
{
if (subprotocol == null)
{
// TODO: this message needs to be returned to Http Client
LOG.warn("not upgraded: no subprotocol selected from offered subprotocols {}: {}", negotiation.getOfferedSubprotocols(), baseRequest);
return false;
}
if (!negotiation.getOfferedSubprotocols().contains(subprotocol))
{
// TODO: this message needs to be returned to Http Client
LOG.warn("not upgraded: selected subprotocol {} not present in offered subprotocols {}: {}", subprotocol, negotiation.getOfferedSubprotocols(),
baseRequest);
LOG.warn("not upgraded: selected subprotocol {} not present in offered subprotocols {}: {}",
subprotocol, negotiation.getOfferedSubprotocols(), baseRequest);
return false;
}
}
else
{
if (!negotiation.getOfferedSubprotocols().isEmpty())
{
// TODO: this message needs to be returned to Http Client
LOG.warn("not upgraded: no subprotocol selected from offered subprotocols {}: {}",
negotiation.getOfferedSubprotocols(), baseRequest);
return false;
}
}
// validate negotiated extensions
negotiation.getOfferedExtensions();
for (ExtensionConfig config : negotiation.getNegotiatedExtensions())
{
long numMatch = negotiation.getOfferedExtensions().stream().filter(c -> config.getName().equalsIgnoreCase(c.getName())).count();
if (numMatch < 1)
{
LOG.warn("Upgrade failed: negotiated extension not requested {}: {}", config.getName(), baseRequest);
return false;
}
if (numMatch > 1)
{
LOG.warn("Upgrade failed: multiple negotiated extensions of the same name {}: {}", config.getName(), baseRequest);
return false;
}
}

View File

@ -167,19 +167,22 @@ public class ServletUpgradeResponse
public void setExtensions(List<ExtensionConfig> configs)
{
// This validation is also done later in RFC6455Handshaker but it is better to fail earlier
for (ExtensionConfig config : configs)
{
List<ExtensionConfig> collect = negotiation.getOfferedExtensions().stream()
.filter(e -> e.getName().equals(config.getName()))
.collect(Collectors.toList());
int matches = (int)negotiation.getOfferedExtensions().stream()
.filter(e -> e.getName().equals(config.getName())).count();
if (collect.size() == 1)
continue;
else if (collect.size() == 0)
switch (matches)
{
case 0:
throw new IllegalArgumentException("Extension not a requested extension");
else if (collect.size() > 1)
case 1:
continue;
default:
throw new IllegalArgumentException("Multiple extensions of the same name");
}
}
negotiation.setNegotiatedExtensions(configs);
}