Issue #3159 - Ensuring we follow permessage-deflate RSV1 rules in RFC7692

Signed-off-by: Joakim Erdfelt <joakim.erdfelt@gmail.com>
This commit is contained in:
Joakim Erdfelt 2019-03-01 16:55:12 -05:00
parent 82cd23f4f0
commit 6444446652
4 changed files with 175 additions and 26 deletions

View File

@ -25,6 +25,7 @@ import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.BadPayloadException;
import org.eclipse.jetty.websocket.api.BatchMode;
import org.eclipse.jetty.websocket.api.ProtocolException;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame;
@ -70,6 +71,12 @@ public class PerMessageDeflateExtension extends CompressExtension
nextIncomingFrame(frame);
return;
}
if (frame.getOpCode() == OpCode.CONTINUATION && frame.isRsv1())
{
// Per RFC7692 we MUST Fail the websocket connection
throw new ProtocolException("Invalid RSV1 set on permessage-deflate CONTINUATION frame");
}
ByteAccumulator accumulator = newByteAccumulator();

View File

@ -18,19 +18,14 @@
package org.eclipse.jetty.websocket.common.extensions;
import static org.hamcrest.MatcherAssert.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.Matchers.is;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingDeque;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.MappedByteBufferPool;
@ -52,9 +47,13 @@ import org.eclipse.jetty.websocket.common.io.FutureWriteCallback;
import org.eclipse.jetty.websocket.common.test.ByteBufferAssert;
import org.eclipse.jetty.websocket.common.test.IncomingFramesCapture;
import org.eclipse.jetty.websocket.common.test.OutgoingFramesCapture;
import org.junit.jupiter.api.Test;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
@SuppressWarnings("Duplicates")
public class FragmentExtensionTest
{
@ -155,7 +154,7 @@ public class FragmentExtensionTest
* @throws IOException on test failure
*/
@Test
public void testOutgoingFramesByMaxLength() throws IOException
public void testOutgoingFramesByMaxLength() throws IOException, InterruptedException
{
OutgoingFramesCapture capture = new OutgoingFramesCapture();
@ -197,11 +196,11 @@ public class FragmentExtensionTest
capture.assertFrameCount(len);
String prefix;
LinkedList<WebSocketFrame> frames = capture.getFrames();
LinkedBlockingDeque<WebSocketFrame> frames = capture.getFrames();
for (int i = 0; i < len; i++)
{
prefix = "Frame[" + i + "]";
WebSocketFrame actualFrame = frames.get(i);
WebSocketFrame actualFrame = frames.poll(1, SECONDS);
WebSocketFrame expectedFrame = expectedFrames.get(i);
// System.out.printf("actual: %s%n",actualFrame);
@ -266,11 +265,11 @@ public class FragmentExtensionTest
capture.assertFrameCount(len);
String prefix;
LinkedList<WebSocketFrame> frames = capture.getFrames();
LinkedBlockingDeque<WebSocketFrame> frames = capture.getFrames();
for (int i = 0; i < len; i++)
{
prefix = "Frame[" + i + "]";
WebSocketFrame actualFrame = frames.get(i);
WebSocketFrame actualFrame = frames.poll(1, SECONDS);
WebSocketFrame expectedFrame = expectedFrames.get(i);
// Validate Frame

View File

@ -18,20 +18,20 @@
package org.eclipse.jetty.websocket.common.extensions.compress;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.MappedByteBufferPool;
import org.eclipse.jetty.toolchain.test.ByteBufferAssert;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.TypeUtil;
import org.eclipse.jetty.websocket.api.BatchMode;
import org.eclipse.jetty.websocket.api.ProtocolException;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame;
@ -42,12 +42,14 @@ import org.eclipse.jetty.websocket.common.extensions.ExtensionTool.Tester;
import org.eclipse.jetty.websocket.common.frames.ContinuationFrame;
import org.eclipse.jetty.websocket.common.frames.PingFrame;
import org.eclipse.jetty.websocket.common.frames.TextFrame;
import org.eclipse.jetty.websocket.common.test.ByteBufferAssert;
import org.eclipse.jetty.websocket.common.test.IncomingFramesCapture;
import org.eclipse.jetty.websocket.common.test.OutgoingFramesCapture;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
* Client side behavioral tests for permessage-deflate extension.
* <p>
@ -224,6 +226,56 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest
tester.assertHasFrames("Hello");
}
/**
* Decode fragmented message (3 parts: TEXT, CONTINUATION, CONTINUATION)
*/
@Test
public void testParseFragmentedMessage_Good()
{
Tester tester = clientExtensions.newTester("permessage-deflate");
tester.assertNegotiated("permessage-deflate");
tester.parseIncomingHex(// 1 message, 3 frame
"410C", // HEADER TEXT / fin=false / rsv1=true
"F248CDC9C95700000000FFFF",
"000B", // HEADER CONTINUATION / fin=false / rsv1=false
"0ACF2FCA4901000000FFFF",
"8003", // HEADER CONTINUATION / fin=true / rsv1=false
"520400"
);
Frame txtFrame = new TextFrame().setPayload("Hello ").setFin(false);
Frame con1Frame = new ContinuationFrame().setPayload("World").setFin(false);
Frame con2Frame = new ContinuationFrame().setPayload("!").setFin(true);
tester.assertHasFrames(txtFrame, con1Frame, con2Frame);
}
/**
* Decode fragmented message (3 parts: TEXT, CONTINUATION, CONTINUATION)
* <p>
* Continuation frames have RSV1 set, which MUST result in Failure
* </p>
*/
@Test
public void testParseFragmentedMessage_BadRsv1()
{
Tester tester = clientExtensions.newTester("permessage-deflate");
tester.assertNegotiated("permessage-deflate");
assertThrows(ProtocolException.class, () ->
tester.parseIncomingHex(// 1 message, 3 frame
"410C", // Header TEXT / fin=false / rsv1=true
"F248CDC9C95700000000FFFF", // Payload
"400B", // Header CONTINUATION / fin=false / rsv1=true
"0ACF2FCA4901000000FFFF", // Payload
"C003", // Header CONTINUATION / fin=true / rsv1=true
"520400" // Payload
));
}
/**
* Incoming PING (Control Frame) should pass through extension unmodified
*/
@ -261,6 +313,44 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest
ByteBufferAssert.assertEquals("Frame.payload", expected, actual.getPayload().slice());
}
/**
* Incoming Text Message fragmented into 3 pieces.
*/
@Test
public void testIncomingFragmented()
{
PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
ext.setBufferPool(bufferPool);
ext.setPolicy(WebSocketPolicy.newServerPolicy());
ExtensionConfig config = ExtensionConfig.parse("permessage-deflate");
ext.setConfig(config);
// Setup capture of incoming frames
IncomingFramesCapture capture = new IncomingFramesCapture();
// Wire up stack
ext.setNextIncomingFrames(capture);
String payload = "Are you there?";
Frame ping = new PingFrame().setPayload(payload);
ext.incomingFrame(ping);
capture.assertFrameCount(1);
capture.assertHasFrame(OpCode.PING, 1);
WebSocketFrame actual = capture.getFrames().poll();
assertThat("Frame.opcode", actual.getOpCode(), is(OpCode.PING));
assertThat("Frame.fin", actual.isFin(), is(true));
assertThat("Frame.rsv1", actual.isRsv1(), is(false));
assertThat("Frame.rsv2", actual.isRsv2(), is(false));
assertThat("Frame.rsv3", actual.isRsv3(), is(false));
ByteBuffer expected = BufferUtil.toBuffer(payload, StandardCharsets.UTF_8);
assertThat("Frame.payloadLength", actual.getPayloadLength(), is(expected.remaining()));
ByteBufferAssert.assertEquals("Frame.payload", expected, actual.getPayload().slice());
}
/**
* Verify that incoming uncompressed frames are properly passed through
*/
@ -356,6 +446,58 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest
ByteBufferAssert.assertEquals("Frame.payload", expected, actual.getPayload().slice());
}
/**
* Outgoing Fragmented Message
* @throws IOException on test failure
*/
@Test
public void testOutgoingFragmentedMessage() throws IOException, InterruptedException
{
PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
ext.setBufferPool(bufferPool);
ext.setPolicy(WebSocketPolicy.newServerPolicy());
ExtensionConfig config = ExtensionConfig.parse("permessage-deflate");
ext.setConfig(config);
// Setup capture of outgoing frames
OutgoingFramesCapture capture = new OutgoingFramesCapture();
// Wire up stack
ext.setNextOutgoingFrames(capture);
Frame txtFrame = new TextFrame().setPayload("Hello ").setFin(false);
Frame con1Frame = new ContinuationFrame().setPayload("World").setFin(false);
Frame con2Frame = new ContinuationFrame().setPayload("!").setFin(true);
ext.outgoingFrame(txtFrame, null, BatchMode.OFF);
ext.outgoingFrame(con1Frame, null, BatchMode.OFF);
ext.outgoingFrame(con2Frame, null, BatchMode.OFF);
capture.assertFrameCount(3);
WebSocketFrame capturedFrame;
capturedFrame = capture.getFrames().poll(1, TimeUnit.SECONDS);
assertThat("Frame.opcode", capturedFrame.getOpCode(), is(OpCode.TEXT));
assertThat("Frame.fin", capturedFrame.isFin(), is(false));
assertThat("Frame.rsv1", capturedFrame.isRsv1(), is(true));
assertThat("Frame.rsv2", capturedFrame.isRsv2(), is(false));
assertThat("Frame.rsv3", capturedFrame.isRsv3(), is(false));
capturedFrame = capture.getFrames().poll(1, TimeUnit.SECONDS);
assertThat("Frame.opcode", capturedFrame.getOpCode(), is(OpCode.CONTINUATION));
assertThat("Frame.fin", capturedFrame.isFin(), is(false));
assertThat("Frame.rsv1", capturedFrame.isRsv1(), is(false));
assertThat("Frame.rsv2", capturedFrame.isRsv2(), is(false));
assertThat("Frame.rsv3", capturedFrame.isRsv3(), is(false));
capturedFrame = capture.getFrames().poll(1, TimeUnit.SECONDS);
assertThat("Frame.opcode", capturedFrame.getOpCode(), is(OpCode.CONTINUATION));
assertThat("Frame.fin", capturedFrame.isFin(), is(true));
assertThat("Frame.rsv1", capturedFrame.isRsv1(), is(false));
assertThat("Frame.rsv2", capturedFrame.isRsv2(), is(false));
assertThat("Frame.rsv3", capturedFrame.isRsv3(), is(false));
}
@Test
public void testPyWebSocket_Client_NoContextTakeover_ThreeOra()
{

View File

@ -18,11 +18,7 @@
package org.eclipse.jetty.websocket.common.test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import java.util.LinkedList;
import java.util.concurrent.LinkedBlockingDeque;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.api.BatchMode;
@ -32,10 +28,14 @@ import org.eclipse.jetty.websocket.api.extensions.OutgoingFrames;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.WebSocketFrame;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
public class OutgoingFramesCapture implements OutgoingFrames
{
private LinkedList<WebSocketFrame> frames = new LinkedList<>();
private LinkedBlockingDeque<WebSocketFrame> frames = new LinkedBlockingDeque<>();
public void assertFrameCount(int expectedCount)
{
@ -60,11 +60,12 @@ public class OutgoingFramesCapture implements OutgoingFrames
public void dump()
{
System.out.printf("Captured %d outgoing writes%n",frames.size());
for (int i = 0; i < frames.size(); i++)
int i=0;
for (WebSocketFrame frame: frames)
{
Frame frame = frames.get(i);
System.out.printf("[%3d] %s%n",i,frame);
System.out.printf(" %s%n",BufferUtil.toDetailString(frame.getPayload()));
i++;
}
}
@ -81,7 +82,7 @@ public class OutgoingFramesCapture implements OutgoingFrames
return count;
}
public LinkedList<WebSocketFrame> getFrames()
public LinkedBlockingDeque<WebSocketFrame> getFrames()
{
return frames;
}