From 736a576f7565a24524e9dd9f713737ba225b87f3 Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Fri, 5 May 2023 18:51:37 +1000 Subject: [PATCH] Review of websocket parser, improve testing & comments. Signed-off-by: Lachlan Roberts --- .../jetty/websocket/core/internal/Parser.java | 32 +++++++--- .../jetty/websocket/core/ParserTest.java | 58 +++++++++++++++++++ 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/Parser.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/Parser.java index 5494662688d..63e19d7ae74 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/Parser.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/Parser.java @@ -55,6 +55,7 @@ public class Parser private int cursor; private byte[] mask; private int payloadLength; + private long longLengthAccumulator; private ByteBuffer aggregate; public Parser(ByteBufferPool bufferPool) @@ -68,6 +69,11 @@ public class Parser this.configuration = configuration; } + public int getPayloadLength() + { + return payloadLength; + } + public void reset() { state = State.START; @@ -75,7 +81,8 @@ public class Parser mask = null; cursor = 0; aggregate = null; - payloadLength = -1; + payloadLength = 0; + longLengthAccumulator = 0; } /** @@ -148,9 +155,12 @@ public class Parser { byte b = buffer.get(); --cursor; - payloadLength |= (b & 0xFF) << (8 * cursor); + longLengthAccumulator |= (long)(b & 0xFF) << (8 * cursor); if (cursor == 0) { + if (longLengthAccumulator > Integer.MAX_VALUE || longLengthAccumulator < 0) + throw new MessageTooLargeException("Frame payload exceeded integer max value"); + payloadLength = Math.toIntExact(longLengthAccumulator); if (mask != null) { state = State.MASK; @@ -250,6 +260,9 @@ public class Parser protected void checkFrameSize(byte opcode, int payloadLength) throws MessageTooLargeException, ProtocolException { + if (payloadLength < 0) + throw new IllegalArgumentException("Invalid payloadLength"); + if (OpCode.isControlFrame(opcode)) { if (payloadLength > Frame.MAX_CONTROL_PAYLOAD) @@ -287,7 +300,7 @@ public class Parser { int shift = fragmentSize % 4; nextMask = new byte[4]; - nextMask[0] = mask[(0 + shift) % 4]; + nextMask[0] = mask[(shift) % 4]; nextMask[1] = mask[(1 + shift) % 4]; nextMask[2] = mask[(2 + shift) % 4]; nextMask[3] = mask[(3 + shift) % 4]; @@ -316,6 +329,7 @@ public class Parser boolean isDataFrame = OpCode.isDataFrame(OpCode.getOpCode(firstByte)); // Always autoFragment data frames if payloadLength is greater than maxFrameSize. + // We have already checked payload size in checkFrameSize, so we know we can autoFragment if larger than maxFrameSize. long maxFrameSize = configuration.getMaxFrameSize(); if (maxFrameSize > 0 && isDataFrame && payloadLength > maxFrameSize) return autoFragment(buffer, (int)Math.min(available, maxFrameSize)); @@ -324,12 +338,12 @@ public class Parser { if (available < payloadLength) { - // not enough to complete this frame - // Can we auto-fragment + // Not enough data to complete this frame, can we auto-fragment? if (configuration.isAutoFragment() && isDataFrame) return autoFragment(buffer, available); - // No space in the buffer, so we have to copy the partial payload + // No space in the buffer, so we have to copy the partial payload. + // The size of this allocation is limited by the maxFrameSize. aggregate = bufferPool.acquire(payloadLength, false); BufferUtil.append(aggregate, buffer); return null; @@ -337,15 +351,15 @@ public class Parser if (available == payloadLength) { - // All the available data is for this frame and completes it + // All the available data is for this frame and completes it. ParsedFrame frame = newFrame(firstByte, mask, buffer.slice(), false); buffer.position(buffer.limit()); state = State.START; return frame; } - // The buffer contains all the data for this frame and for subsequent frames - // Copy the just the first part of the buffer as frame payload + // The buffer contains all the data for this frame and for subsequent frames. + // Copy just the first part of the buffer as the frame payload. int limit = buffer.limit(); int end = buffer.position() + payloadLength; buffer.limit(end); diff --git a/jetty-websocket/websocket-core-tests/src/test/java/org/eclipse/jetty/websocket/core/ParserTest.java b/jetty-websocket/websocket-core-tests/src/test/java/org/eclipse/jetty/websocket/core/ParserTest.java index d6c1b204719..0994691fd7e 100644 --- a/jetty-websocket/websocket-core-tests/src/test/java/org/eclipse/jetty/websocket/core/ParserTest.java +++ b/jetty-websocket/websocket-core-tests/src/test/java/org/eclipse/jetty/websocket/core/ParserTest.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; +import org.eclipse.jetty.io.NullByteBufferPool; import org.eclipse.jetty.toolchain.test.Hex; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.StringUtil; @@ -33,6 +34,7 @@ import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; @@ -40,6 +42,7 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.sameInstance; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -227,6 +230,61 @@ public class ParserTest assertThat("Frame.payloadLength", pActual.getPayloadLength(), is(length)); } + private ByteBuffer toBuffer(long l) + { + ByteBuffer buffer = BufferUtil.allocate(Long.BYTES); + BufferUtil.clearToFill(buffer); + buffer.putLong(l); + BufferUtil.flipToFlush(buffer, 0); + return buffer; + } + + @Test + public void testLargeFrame() + { + ByteBuffer expected = ByteBuffer.allocate(65); + + expected.put(new byte[]{(byte)0x82}); + byte b = 0x7F; // no masking + expected.put(b); + expected.put(toBuffer(Integer.MAX_VALUE)); + expected.flip(); + + Parser parser = new Parser(new NullByteBufferPool()); + assertNull(parser.parse(expected)); + assertThat(parser.getPayloadLength(), equalTo(Integer.MAX_VALUE)); + } + + @Test + public void testFrameTooLarge() + { + ByteBuffer expected = ByteBuffer.allocate(65); + + expected.put(new byte[]{(byte)0x82}); + byte b = 0x7F; // no masking + expected.put(b); + expected.put(toBuffer(Integer.MAX_VALUE + 1L)); + expected.flip(); + + Parser parser = new Parser(new NullByteBufferPool()); + assertThrows(MessageTooLargeException.class, () -> parser.parse(expected)); + } + + @Test + public void testLargestFrame() + { + ByteBuffer expected = ByteBuffer.allocate(65); + + expected.put(new byte[]{(byte)0x82}); + byte b = 0x7F; // no masking + expected.put(b); + expected.put(new byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}); + expected.flip(); + + Parser parser = new Parser(new NullByteBufferPool()); + assertThrows(MessageTooLargeException.class, () -> parser.parse(expected)); + } + /** * From Autobahn WebSocket Server Testcase 1.2.6 */