Merge pull request #3470 from lachlan-roberts/jetty-10.0.x-3462-websocketclient-validation

Issue #3462 - websocket upgrade request valdiation
This commit is contained in:
Greg Wilkins 2019-03-19 16:11:09 +11:00 committed by GitHub
commit f037258725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 106 additions and 99 deletions

View File

@ -631,7 +631,7 @@ public class ConfiguratorTest
FrameHandlerTracker clientSocket = new FrameHandlerTracker(); FrameHandlerTracker clientSocket = new FrameHandlerTracker();
ClientUpgradeRequest upgradeRequest = ClientUpgradeRequest.from(client, wsUri, clientSocket); ClientUpgradeRequest upgradeRequest = ClientUpgradeRequest.from(client, wsUri, clientSocket);
upgradeRequest.header("sec-websocket-protocol", "echo, chat, status"); upgradeRequest.setSubProtocols("echo","chat","status");
Future<FrameHandler.CoreSession> clientConnectFuture = client.connect(upgradeRequest); Future<FrameHandler.CoreSession> clientConnectFuture = client.connect(upgradeRequest);
assertProtocols(clientSocket, clientConnectFuture, is("Requested Protocols: [echo,chat,status]")); assertProtocols(clientSocket, clientConnectFuture, is("Requested Protocols: [echo,chat,status]"));
@ -650,8 +650,7 @@ public class ConfiguratorTest
FrameHandlerTracker clientSocket = new FrameHandlerTracker(); FrameHandlerTracker clientSocket = new FrameHandlerTracker();
ClientUpgradeRequest upgradeRequest = ClientUpgradeRequest.from(client, wsUri, clientSocket); ClientUpgradeRequest upgradeRequest = ClientUpgradeRequest.from(client, wsUri, clientSocket);
// header name is not to spec (case wise) upgradeRequest.setSubProtocols("echo","chat","status");
upgradeRequest.header("Sec-Websocket-Protocol", "echo, chat, status");
Future<FrameHandler.CoreSession> clientConnectFuture = client.connect(upgradeRequest); Future<FrameHandler.CoreSession> clientConnectFuture = client.connect(upgradeRequest);
assertProtocols(clientSocket, clientConnectFuture, is("Requested Protocols: [echo,chat,status]")); assertProtocols(clientSocket, clientConnectFuture, is("Requested Protocols: [echo,chat,status]"));

View File

@ -20,13 +20,13 @@ package org.eclipse.jetty.websocket.core.client;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.eclipse.jetty.client.HttpClient; import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.HttpConversation; import org.eclipse.jetty.client.HttpConversation;
@ -83,14 +83,6 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
protected final CompletableFuture<FrameHandler.CoreSession> futureCoreSession; protected final CompletableFuture<FrameHandler.CoreSession> futureCoreSession;
private final WebSocketCoreClient wsClient; private final WebSocketCoreClient wsClient;
private List<UpgradeListener> upgradeListeners = new ArrayList<>(); private List<UpgradeListener> upgradeListeners = new ArrayList<>();
/**
* Offered Extensions
*/
private List<ExtensionConfig> extensions = new ArrayList<>();
/**
* Offered SubProtocols
*/
private List<String> subProtocols = new ArrayList<>();
public ClientUpgradeRequest(WebSocketCoreClient webSocketClient, URI requestURI) public ClientUpgradeRequest(WebSocketCoreClient webSocketClient, URI requestURI)
{ {
@ -133,47 +125,56 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
public void addExtensions(ExtensionConfig... configs) public void addExtensions(ExtensionConfig... configs)
{ {
HttpFields headers = getHeaders();
for (ExtensionConfig config : configs) for (ExtensionConfig config : configs)
{ headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
this.extensions.add(config);
}
updateWebSocketExtensionHeader();
} }
public void addExtensions(String... configs) public void addExtensions(String... configs)
{ {
this.extensions.addAll(ExtensionConfig.parseList(configs)); HttpFields headers = getHeaders();
updateWebSocketExtensionHeader(); for (String config : configs)
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, ExtensionConfig.parse(config).getParameterizedName());
} }
public List<ExtensionConfig> getExtensions() public List<ExtensionConfig> getExtensions()
{ {
List<ExtensionConfig> extensions = getHeaders().getCSV(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, true)
.stream()
.map(ExtensionConfig::parse)
.collect(Collectors.toList());
return extensions; return extensions;
} }
public void setExtensions(List<ExtensionConfig> configs) public void setExtensions(List<ExtensionConfig> configs)
{ {
this.extensions = configs; HttpFields headers = getHeaders();
updateWebSocketExtensionHeader(); headers.remove(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
for (ExtensionConfig config : configs)
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
} }
public List<String> getSubProtocols() public List<String> getSubProtocols()
{ {
return this.subProtocols; List<String> subProtocols = getHeaders().getCSV(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, true);
return subProtocols;
} }
public void setSubProtocols(String... protocols) public void setSubProtocols(String... protocols)
{ {
this.subProtocols.clear(); HttpFields headers = getHeaders();
this.subProtocols.addAll(Arrays.asList(protocols)); headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
updateWebSocketSubProtocolHeader(); for (String protocol : protocols)
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
} }
public void setSubProtocols(List<String> protocols) public void setSubProtocols(List<String> protocols)
{ {
this.subProtocols.clear(); HttpFields headers = getHeaders();
this.subProtocols.addAll(protocols); headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
updateWebSocketSubProtocolHeader(); for (String protocol : protocols)
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
} }
@Override @Override
@ -249,25 +250,16 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
public void upgrade(HttpResponse response, HttpConnectionOverHTTP httpConnection) public void upgrade(HttpResponse response, HttpConnectionOverHTTP httpConnection)
{ {
if (!this.getHeaders().get(HttpHeader.UPGRADE).equalsIgnoreCase("websocket")) if (!this.getHeaders().get(HttpHeader.UPGRADE).equalsIgnoreCase("websocket"))
{
// Not my upgrade
throw new HttpResponseException("Not a WebSocket Upgrade", response); throw new HttpResponseException("Not a WebSocket Upgrade", response);
}
HttpClient httpClient = wsClient.getHttpClient();
// Check the Accept hash // Check the Accept hash
String reqKey = this.getHeaders().get(HttpHeader.SEC_WEBSOCKET_KEY); String reqKey = this.getHeaders().get(HttpHeader.SEC_WEBSOCKET_KEY);
String expectedHash = WebSocketCore.hashKey(reqKey); String expectedHash = WebSocketCore.hashKey(reqKey);
String respHash = response.getHeaders().get(HttpHeader.SEC_WEBSOCKET_ACCEPT); String respHash = response.getHeaders().get(HttpHeader.SEC_WEBSOCKET_ACCEPT);
if (expectedHash.equalsIgnoreCase(respHash) == false) if (expectedHash.equalsIgnoreCase(respHash) == false)
{
throw new HttpResponseException("Invalid Sec-WebSocket-Accept hash (was:" + respHash + ", expected:" + expectedHash + ")", response); throw new HttpResponseException("Invalid Sec-WebSocket-Accept hash (was:" + respHash + ", expected:" + expectedHash + ")", response);
}
// Verify the Negotiated Extensions // Parse the Negotiated Extensions
ExtensionStack extensionStack = new ExtensionStack(wsClient.getExtensionRegistry());
List<ExtensionConfig> extensions = new ArrayList<>(); List<ExtensionConfig> extensions = new ArrayList<>();
HttpField extField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_EXTENSIONS); HttpField extField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
if (extField != null) if (extField != null)
@ -286,9 +278,23 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
} }
} }
// Verify the Negotiated Extensions
List<ExtensionConfig> offeredExtensions = getExtensions();
for (ExtensionConfig config : extensions)
{
long numMatch = offeredExtensions.stream().filter(c -> config.getName().equalsIgnoreCase(c.getName())).count();
if (numMatch < 1)
throw new WebSocketException("Upgrade failed: Sec-WebSocket-Extensions contained extension not requested");
if (numMatch > 1)
throw new WebSocketException("Upgrade failed: Sec-WebSocket-Extensions contained more than one extension of the same name");
}
// Negotiate the extension stack
HttpClient httpClient = wsClient.getHttpClient();
ExtensionStack extensionStack = new ExtensionStack(wsClient.getExtensionRegistry());
extensionStack.negotiate(wsClient.getObjectFactory(), httpClient.getByteBufferPool(), extensions); extensionStack.negotiate(wsClient.getObjectFactory(), httpClient.getByteBufferPool(), extensions);
// Check the negotiated subprotocol // Get the negotiated subprotocol
String negotiatedSubProtocol = null; String negotiatedSubProtocol = null;
HttpField subProtocolField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL); HttpField subProtocolField = response.getHeaders().getField(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
if (subProtocolField != null) if (subProtocolField != null)
@ -297,20 +303,18 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
if (values != null) if (values != null)
{ {
if (values.length > 1) if (values.length > 1)
{ throw new WebSocketException("Upgrade failed: Too many WebSocket subprotocol's in response: " + values);
throw new WebSocketException("Too many WebSocket subprotocol's in response: " + values);
}
else if (values.length == 1) else if (values.length == 1)
{
negotiatedSubProtocol = values[0]; negotiatedSubProtocol = values[0];
} }
} }
}
if (!subProtocols.isEmpty() && !subProtocols.contains(negotiatedSubProtocol)) // Verify the negotiated subprotocol
{ List<String> offeredSubProtocols = getSubProtocols();
throw new WebSocketException("Upgrade failed: subprotocol [" + negotiatedSubProtocol + "] not found in offered subprotocols " + subProtocols); if (negotiatedSubProtocol == null && !offeredSubProtocols.isEmpty())
} throw new WebSocketException("Upgrade failed: no subprotocol selected from offered subprotocols ");
if (negotiatedSubProtocol != null && !offeredSubProtocols.contains(negotiatedSubProtocol))
throw new WebSocketException("Upgrade failed: subprotocol [" + negotiatedSubProtocol + "] not found in offered subprotocols " + offeredSubProtocols);
// We can upgrade // We can upgrade
EndPoint endp = httpConnection.getEndPoint(); EndPoint endp = httpConnection.getEndPoint();
@ -435,24 +439,4 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
} }
} }
} }
private void updateWebSocketExtensionHeader()
{
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
for (ExtensionConfig config : extensions)
{
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
}
private void updateWebSocketSubProtocolHeader()
{
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : subProtocols)
{
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
}
} }

View File

@ -18,6 +18,16 @@
package org.eclipse.jetty.websocket.core.server; package org.eclipse.jetty.websocket.core.server;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpField; import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.QuotedCSV; import org.eclipse.jetty.http.QuotedCSV;
@ -28,15 +38,6 @@ import org.eclipse.jetty.websocket.core.ExtensionConfig;
import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry; import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry;
import org.eclipse.jetty.websocket.core.internal.ExtensionStack; import org.eclipse.jetty.websocket.core.internal.ExtensionStack;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.stream.Collectors;
public class Negotiation public class Negotiation
{ {
private final Request baseRequest; private final Request baseRequest;

View File

@ -43,6 +43,7 @@ import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.Behavior; import org.eclipse.jetty.websocket.core.Behavior;
import org.eclipse.jetty.websocket.core.ExtensionConfig;
import org.eclipse.jetty.websocket.core.FrameHandler; import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.WebSocketConstants; import org.eclipse.jetty.websocket.core.WebSocketConstants;
import org.eclipse.jetty.websocket.core.internal.Negotiated; import org.eclipse.jetty.websocket.core.internal.Negotiated;
@ -60,8 +61,7 @@ public final class RFC6455Handshaker implements Handshaker
private static final HttpField CONNECTION_UPGRADE = new PreEncodedHttpField(HttpHeader.CONNECTION, HttpHeader.UPGRADE.asString()); private static final HttpField CONNECTION_UPGRADE = new PreEncodedHttpField(HttpHeader.CONNECTION, HttpHeader.UPGRADE.asString());
private static final HttpField SERVER_VERSION = new PreEncodedHttpField(HttpHeader.SERVER, HttpConfiguration.SERVER_VERSION); private static final HttpField SERVER_VERSION = new PreEncodedHttpField(HttpHeader.SERVER, HttpConfiguration.SERVER_VERSION);
public boolean upgradeRequest(WebSocketNegotiator negotiator, HttpServletRequest request, public boolean upgradeRequest(WebSocketNegotiator negotiator, HttpServletRequest request, HttpServletResponse response,
HttpServletResponse response,
FrameHandler.Customizer defaultCustomizer) throws IOException FrameHandler.Customizer defaultCustomizer) throws IOException
{ {
Request baseRequest = Request.getBaseRequest(request); Request baseRequest = Request.getBaseRequest(request);
@ -153,22 +153,42 @@ public final class RFC6455Handshaker implements Handshaker
return false; return false;
} }
// Check if subprotocol negotiated // validate negotiated subprotocol
String subprotocol = negotiation.getSubprotocol(); String subprotocol = negotiation.getSubprotocol();
if (negotiation.getOfferedSubprotocols().size() > 0) if (subprotocol != null)
{ {
if (subprotocol == null)
{
// TODO: this message needs to be returned to Http Client
LOG.warn("not upgraded: no subprotocol selected from offered subprotocols {}: {}", negotiation.getOfferedSubprotocols(), baseRequest);
return false;
}
if (!negotiation.getOfferedSubprotocols().contains(subprotocol)) if (!negotiation.getOfferedSubprotocols().contains(subprotocol))
{ {
// TODO: this message needs to be returned to Http Client // TODO: this message needs to be returned to Http Client
LOG.warn("not upgraded: selected subprotocol {} not present in offered subprotocols {}: {}", subprotocol, negotiation.getOfferedSubprotocols(), LOG.warn("not upgraded: selected subprotocol {} not present in offered subprotocols {}: {}",
baseRequest); subprotocol, negotiation.getOfferedSubprotocols(), baseRequest);
return false;
}
}
else
{
if (!negotiation.getOfferedSubprotocols().isEmpty())
{
// TODO: this message needs to be returned to Http Client
LOG.warn("not upgraded: no subprotocol selected from offered subprotocols {}: {}",
negotiation.getOfferedSubprotocols(), baseRequest);
return false;
}
}
// validate negotiated extensions
negotiation.getOfferedExtensions();
for (ExtensionConfig config : negotiation.getNegotiatedExtensions())
{
long numMatch = negotiation.getOfferedExtensions().stream().filter(c -> config.getName().equalsIgnoreCase(c.getName())).count();
if (numMatch < 1)
{
LOG.warn("Upgrade failed: negotiated extension not requested {}: {}", config.getName(), baseRequest);
return false;
}
if (numMatch > 1)
{
LOG.warn("Upgrade failed: multiple negotiated extensions of the same name {}: {}", config.getName(), baseRequest);
return false; return false;
} }
} }

View File

@ -167,19 +167,22 @@ public class ServletUpgradeResponse
public void setExtensions(List<ExtensionConfig> configs) public void setExtensions(List<ExtensionConfig> configs)
{ {
// This validation is also done later in RFC6455Handshaker but it is better to fail earlier
for (ExtensionConfig config : configs) for (ExtensionConfig config : configs)
{ {
List<ExtensionConfig> collect = negotiation.getOfferedExtensions().stream() int matches = (int)negotiation.getOfferedExtensions().stream()
.filter(e -> e.getName().equals(config.getName())) .filter(e -> e.getName().equals(config.getName())).count();
.collect(Collectors.toList());
if (collect.size() == 1) switch (matches)
continue; {
else if (collect.size() == 0) case 0:
throw new IllegalArgumentException("Extension not a requested extension"); throw new IllegalArgumentException("Extension not a requested extension");
else if (collect.size() > 1) case 1:
continue;
default:
throw new IllegalArgumentException("Multiple extensions of the same name"); throw new IllegalArgumentException("Multiple extensions of the same name");
} }
}
negotiation.setNegotiatedExtensions(configs); negotiation.setNegotiatedExtensions(configs);
} }