From 9cd96c3c93c39edd9dcb37d104804d2233071749 Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Wed, 27 Feb 2019 14:09:42 +1100 Subject: [PATCH] Issue #3159 - Ensuring we follow permessage-deflate RSV1 rules in RFC7692 Signed-off-by: Lachlan Roberts --- .../compress/PerMessageDeflateExtension.java | 18 +++--- .../core/extensions/ExtensionTool.java | 3 +- .../extensions/FragmentExtensionTest.java | 19 +++--- .../PerMessageDeflateExtensionTest.java | 64 +++++++++++++++++-- 4 files changed, 81 insertions(+), 23 deletions(-) diff --git a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/compress/PerMessageDeflateExtension.java b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/compress/PerMessageDeflateExtension.java index dc97224481a..e5f47588fc8 100644 --- a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/compress/PerMessageDeflateExtension.java +++ b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/internal/compress/PerMessageDeflateExtension.java @@ -63,9 +63,17 @@ public class PerMessageDeflateExtension extends CompressExtension // This extension requires the RSV1 bit set only in the first frame. // Subsequent continuation frames don't have RSV1 set, but are compressed. - if (OpCode.isDataFrame(frame.getOpCode()) && frame.getOpCode() != OpCode.CONTINUATION) + switch (frame.getOpCode()) { - incomingCompressed = frame.isRsv1(); + case OpCode.TEXT: + case OpCode.BINARY: + incomingCompressed = frame.isRsv1(); + break; + + case OpCode.CONTINUATION: + if (frame.isRsv1()) + callback.failed(new ProtocolException("Invalid RSV1 set on permessage-deflate CONTINUATION frame")); + break; } if (OpCode.isControlFrame(frame.getOpCode()) || !incomingCompressed) @@ -74,12 +82,6 @@ public class PerMessageDeflateExtension extends CompressExtension return; } - if (frame.getOpCode() == OpCode.CONTINUATION && frame.isRsv1()) - { - // Per RFC7692 we MUST Fail the websocket connection - throw new ProtocolException("Invalid RSV1 set on permessage-deflate CONTINUATION frame"); - } - //TODO fix this to use long instead of int if (getWebSocketChannel().getMaxFrameSize() > Integer.MAX_VALUE) throw new IllegalArgumentException("maxFrameSize too large for ByteAccumulator"); diff --git a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/ExtensionTool.java b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/ExtensionTool.java index aa2b680b7b0..519b4d8b437 100644 --- a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/ExtensionTool.java +++ b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/ExtensionTool.java @@ -41,6 +41,7 @@ import org.eclipse.jetty.websocket.core.internal.Negotiated; import org.eclipse.jetty.websocket.core.internal.Parser; import org.eclipse.jetty.websocket.core.internal.WebSocketChannel; import org.hamcrest.Matchers; +import org.junit.jupiter.api.Assertions; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -95,7 +96,7 @@ public class ExtensionTool Frame frame = parser.parse(buffer); if (frame == null) break; - ext.onFrame(frame, Callback.NOOP); + ext.onFrame(frame, Callback.from(()->{}, Assertions::fail)); } } } diff --git a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/FragmentExtensionTest.java b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/FragmentExtensionTest.java index d6dfa9112de..855e0719d0f 100644 --- a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/FragmentExtensionTest.java +++ b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/FragmentExtensionTest.java @@ -35,6 +35,9 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -126,10 +129,10 @@ public class FragmentExtensionTest extends AbstractExtensionTest /** * Verify that outgoing text frames are fragmented by the maxLength configuration. * - * @throws IOException on test failure + * @throws Exception on test failure */ @Test - public void testOutgoingFramesByMaxLength() throws IOException + public void testOutgoingFramesByMaxLength() throws Exception { OutgoingFramesCapture capture = new OutgoingFramesCapture(); @@ -169,11 +172,11 @@ public class FragmentExtensionTest extends AbstractExtensionTest capture.assertFrameCount(len); String prefix; - LinkedList frames = new LinkedList<>(capture.frames); + BlockingQueue frames = capture.frames; for (int i = 0; i < len; i++) { prefix = "Frame[" + i + "]"; - Frame actualFrame = frames.get(i); + Frame actualFrame = frames.poll(1, TimeUnit.SECONDS); Frame expectedFrame = expectedFrames.get(i); // System.out.printf("actual: %s%n",actualFrame); @@ -198,10 +201,10 @@ public class FragmentExtensionTest extends AbstractExtensionTest /** * Verify that outgoing text frames are fragmented by default configuration * - * @throws IOException on test failure + * @throws Exception on test failure */ @Test - public void testOutgoingFramesDefaultConfig() throws IOException + public void testOutgoingFramesDefaultConfig() throws Exception { OutgoingFramesCapture capture = new OutgoingFramesCapture(); @@ -236,11 +239,11 @@ public class FragmentExtensionTest extends AbstractExtensionTest capture.assertFrameCount(len); String prefix; - LinkedList frames = new LinkedList<>(capture.frames); + BlockingQueue frames = capture.frames; for (int i = 0; i < len; i++) { prefix = "Frame[" + i + "]"; - Frame actualFrame = frames.get(i); + Frame actualFrame = frames.poll(1, TimeUnit.SECONDS); Frame expectedFrame = expectedFrames.get(i); // Validate Frame diff --git a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/PerMessageDeflateExtensionTest.java b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/PerMessageDeflateExtensionTest.java index 8e10baa967c..6527e57efea 100644 --- a/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/PerMessageDeflateExtensionTest.java +++ b/jetty-websocket/websocket-core/src/test/java/org/eclipse/jetty/websocket/core/extensions/PerMessageDeflateExtensionTest.java @@ -38,7 +38,9 @@ import org.eclipse.jetty.websocket.core.internal.compress.CompressExtension; import org.eclipse.jetty.websocket.core.internal.compress.PerMessageDeflateExtension; import org.junit.jupiter.api.Test; +import java.util.concurrent.TimeUnit; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -240,9 +242,9 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest "520400" ); - Frame txtFrame = new Frame(OpCode.TEXT).setPayload("Hello ").setFin(false); - Frame con1Frame = new Frame(OpCode.CONTINUATION).setPayload("World").setFin(false); - Frame con2Frame = new Frame(OpCode.CONTINUATION).setPayload("!").setFin(true); + Frame txtFrame = new Frame(OpCode.TEXT, false, "Hello "); + Frame con1Frame = new Frame(OpCode.CONTINUATION, false, "World"); + Frame con2Frame = new Frame(OpCode.CONTINUATION, true, "!"); tester.assertHasFrames(txtFrame, con1Frame, con2Frame); } @@ -260,7 +262,7 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest tester.assertNegotiated("permessage-deflate"); - assertThrows(ProtocolException.class, () -> + Throwable t = assertThrows(Throwable.class, () -> tester.parseIncomingHex(// 1 message, 3 frame "410C", // Header TEXT / fin=false / rsv1=true "F248CDC9C95700000000FFFF", // Payload @@ -269,6 +271,9 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest "C003", // Header CONTINUATION / fin=true / rsv1=true "520400" // Payload )); + + assertThat(t.getCause(), instanceOf(ProtocolException.class)); + assertThat(t.getCause().getMessage(), is("Invalid RSV1 set on permessage-deflate CONTINUATION frame")); } /** @@ -341,8 +346,6 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest ByteBufferAssert.assertEquals("Frame.payload", expected, actual.getPayload().slice()); } - /** - /** * Verify that incoming uncompressed frames are properly passed through */ @@ -435,6 +438,55 @@ public class PerMessageDeflateExtensionTest extends AbstractExtensionTest ByteBufferAssert.assertEquals("Frame.payload", expected, actual.getPayload().slice()); } + /** + * Outgoing Fragmented Message + * @throws IOException on test failure + */ + @Test + public void testOutgoingFragmentedMessage() throws IOException, InterruptedException + { + PerMessageDeflateExtension ext = new PerMessageDeflateExtension(); + ext.init(ExtensionConfig.parse("permessage-deflate"), bufferPool); + + // Setup capture of outgoing frames + OutgoingFramesCapture capture = new OutgoingFramesCapture(); + + // Wire up stack + ext.setNextOutgoingFrames(capture); + + Frame txtFrame = new Frame(OpCode.TEXT, false, "Hello "); + Frame con1Frame = new Frame(OpCode.CONTINUATION, false, "World"); + Frame con2Frame = new Frame(OpCode.CONTINUATION, true, "!"); + ext.sendFrame(txtFrame, Callback.NOOP, false); + ext.sendFrame(con1Frame, Callback.NOOP, false); + ext.sendFrame(con2Frame, Callback.NOOP, false); + + capture.assertFrameCount(3); + + Frame capturedFrame; + + capturedFrame = capture.frames.poll(1, TimeUnit.SECONDS); + assertThat("Frame.opcode", capturedFrame.getOpCode(), is(OpCode.TEXT)); + assertThat("Frame.fin", capturedFrame.isFin(), is(false)); + assertThat("Frame.rsv1", capturedFrame.isRsv1(), is(true)); + assertThat("Frame.rsv2", capturedFrame.isRsv2(), is(false)); + assertThat("Frame.rsv3", capturedFrame.isRsv3(), is(false)); + + capturedFrame = capture.frames.poll(1, TimeUnit.SECONDS); + assertThat("Frame.opcode", capturedFrame.getOpCode(), is(OpCode.CONTINUATION)); + assertThat("Frame.fin", capturedFrame.isFin(), is(false)); + assertThat("Frame.rsv1", capturedFrame.isRsv1(), is(false)); + assertThat("Frame.rsv2", capturedFrame.isRsv2(), is(false)); + assertThat("Frame.rsv3", capturedFrame.isRsv3(), is(false)); + + capturedFrame = capture.frames.poll(1, TimeUnit.SECONDS); + assertThat("Frame.opcode", capturedFrame.getOpCode(), is(OpCode.CONTINUATION)); + assertThat("Frame.fin", capturedFrame.isFin(), is(true)); + assertThat("Frame.rsv1", capturedFrame.isRsv1(), is(false)); + assertThat("Frame.rsv2", capturedFrame.isRsv2(), is(false)); + assertThat("Frame.rsv3", capturedFrame.isRsv3(), is(false)); + } + @Test public void testPyWebSocket_Client_NoContextTakeover_ThreeOra() {