Issue #3462 - parse extensions and subprotocols from headers every time

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2019-03-19 11:20:49 +11:00
parent 96027f5437
commit cdd3ed943c
2 changed files with 42 additions and 68 deletions

View File

@ -27,19 +27,15 @@ import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.eclipse.jetty.client.HttpResponse;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.QuotedCSV;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.SharedBlockingCallback;
import org.eclipse.jetty.websocket.core.Behavior;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.ExtensionConfig;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.client.ClientUpgradeRequest;
@ -77,22 +73,8 @@ public class NetworkFuzzer extends Fuzzer.Adapter implements Fuzzer, AutoCloseab
HttpFields fields = this.upgradeRequest.getHeaders();
requestHeaders.forEach((name, value) ->
{
if (HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL.toString().equalsIgnoreCase(name))
{
QuotedCSV subprotocols = new QuotedCSV(value);
upgradeRequest.setSubProtocols(subprotocols.getValues().stream().collect(Collectors.toList()));
}
else if (HttpHeader.SEC_WEBSOCKET_EXTENSIONS.toString().equalsIgnoreCase(name))
{
QuotedCSV extensions = new QuotedCSV(value);
upgradeRequest.setExtensions(
extensions.getValues().stream().map(ExtensionConfig::parse).collect(Collectors.toList()));
}
else
{
fields.remove(name);
fields.put(name, value);
}
fields.remove(name);
fields.put(name, value);
});
}
this.client.start();

View File

@ -20,7 +20,6 @@ package org.eclipse.jetty.websocket.core.client;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
@ -45,6 +44,7 @@ import org.eclipse.jetty.http.HttpMethod;
import org.eclipse.jetty.http.HttpScheme;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.http.QuotedCSV;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
@ -83,14 +83,6 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
protected final CompletableFuture<FrameHandler.CoreSession> futureCoreSession;
private final WebSocketCoreClient wsClient;
private List<UpgradeListener> upgradeListeners = new ArrayList<>();
/**
* Offered Extensions
*/
private List<ExtensionConfig> extensions = new ArrayList<>();
/**
* Offered SubProtocols
*/
private List<String> subProtocols = new ArrayList<>();
public ClientUpgradeRequest(WebSocketCoreClient webSocketClient, URI requestURI)
{
@ -133,47 +125,65 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
public void addExtensions(ExtensionConfig... configs)
{
HttpFields headers = getHeaders();
for (ExtensionConfig config : configs)
{
this.extensions.add(config);
}
updateWebSocketExtensionHeader();
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
public void addExtensions(String... configs)
{
this.extensions.addAll(ExtensionConfig.parseList(configs));
updateWebSocketExtensionHeader();
HttpFields headers = getHeaders();
for (String config : configs)
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, ExtensionConfig.parse(config).getParameterizedName());
}
public List<ExtensionConfig> getExtensions()
{
List<ExtensionConfig> extensions = new ArrayList<>();
getHeaders().stream()
.filter(field -> field.getName().equalsIgnoreCase(HttpHeader.SEC_WEBSOCKET_EXTENSIONS.toString()))
.forEach(field -> new QuotedCSV(field.getValue()).getValues().stream()
.map(ExtensionConfig::parse)
.forEach(extensions::add)
);
return extensions;
}
public void setExtensions(List<ExtensionConfig> configs)
{
this.extensions = configs;
updateWebSocketExtensionHeader();
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
for (ExtensionConfig config : configs)
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
public List<String> getSubProtocols()
{
return this.subProtocols;
List<String> subProtocols = new ArrayList<>();
getHeaders().stream()
.filter(field -> field.getName().equalsIgnoreCase(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL.toString()))
.forEach(field -> subProtocols.addAll(new QuotedCSV(field.getValue()).getValues()));
return subProtocols;
}
public void setSubProtocols(String... protocols)
{
this.subProtocols.clear();
this.subProtocols.addAll(Arrays.asList(protocols));
updateWebSocketSubProtocolHeader();
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : protocols)
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
public void setSubProtocols(List<String> protocols)
{
this.subProtocols.clear();
this.subProtocols.addAll(protocols);
updateWebSocketSubProtocolHeader();
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : protocols)
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
@Override
@ -278,9 +288,10 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
}
// Verify the Negotiated Extensions
List<ExtensionConfig> offeredExtensions = getExtensions();
for (ExtensionConfig config : extensions)
{
long numMatch = this.extensions.stream().filter(c -> config.getName().equalsIgnoreCase(c.getName())).count();
long numMatch = offeredExtensions.stream().filter(c -> config.getName().equalsIgnoreCase(c.getName())).count();
if (numMatch < 1)
throw new WebSocketException("Upgrade failed: Sec-WebSocket-Extensions contained extension not requested");
if (numMatch > 1)
@ -308,10 +319,11 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
}
// Verify the negotiated subprotocol
if (negotiatedSubProtocol == null && !subProtocols.isEmpty())
List<String> offeredSubProtocols = getSubProtocols();
if (negotiatedSubProtocol == null && !offeredSubProtocols.isEmpty())
throw new WebSocketException("Upgrade failed: no subprotocol selected from offered subprotocols ");
if (negotiatedSubProtocol != null && !subProtocols.contains(negotiatedSubProtocol))
throw new WebSocketException("Upgrade failed: subprotocol [" + negotiatedSubProtocol + "] not found in offered subprotocols " + subProtocols);
if (negotiatedSubProtocol != null && !offeredSubProtocols.contains(negotiatedSubProtocol))
throw new WebSocketException("Upgrade failed: subprotocol [" + negotiatedSubProtocol + "] not found in offered subprotocols " + offeredSubProtocols);
// We can upgrade
EndPoint endp = httpConnection.getEndPoint();
@ -436,24 +448,4 @@ public abstract class ClientUpgradeRequest extends HttpRequest implements Respon
}
}
}
private void updateWebSocketExtensionHeader()
{
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_EXTENSIONS);
for (ExtensionConfig config : extensions)
{
headers.add(HttpHeader.SEC_WEBSOCKET_EXTENSIONS, config.getParameterizedName());
}
}
private void updateWebSocketSubProtocolHeader()
{
HttpFields headers = getHeaders();
headers.remove(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL);
for (String protocol : subProtocols)
{
headers.add(HttpHeader.SEC_WEBSOCKET_SUBPROTOCOL, protocol);
}
}
}