Review of websocket parser, improve testing & comments.

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2023-05-05 18:51:37 +10:00
parent 12581c0ea0
commit 736a576f75
2 changed files with 81 additions and 9 deletions

View File

@ -55,6 +55,7 @@ public class Parser
private int cursor;
private byte[] mask;
private int payloadLength;
private long longLengthAccumulator;
private ByteBuffer aggregate;
public Parser(ByteBufferPool bufferPool)
@ -68,6 +69,11 @@ public class Parser
this.configuration = configuration;
}
public int getPayloadLength()
{
return payloadLength;
}
public void reset()
{
state = State.START;
@ -75,7 +81,8 @@ public class Parser
mask = null;
cursor = 0;
aggregate = null;
payloadLength = -1;
payloadLength = 0;
longLengthAccumulator = 0;
}
/**
@ -148,9 +155,12 @@ public class Parser
{
byte b = buffer.get();
--cursor;
payloadLength |= (b & 0xFF) << (8 * cursor);
longLengthAccumulator |= (long)(b & 0xFF) << (8 * cursor);
if (cursor == 0)
{
if (longLengthAccumulator > Integer.MAX_VALUE || longLengthAccumulator < 0)
throw new MessageTooLargeException("Frame payload exceeded integer max value");
payloadLength = Math.toIntExact(longLengthAccumulator);
if (mask != null)
{
state = State.MASK;
@ -250,6 +260,9 @@ public class Parser
protected void checkFrameSize(byte opcode, int payloadLength) throws MessageTooLargeException, ProtocolException
{
if (payloadLength < 0)
throw new IllegalArgumentException("Invalid payloadLength");
if (OpCode.isControlFrame(opcode))
{
if (payloadLength > Frame.MAX_CONTROL_PAYLOAD)
@ -287,7 +300,7 @@ public class Parser
{
int shift = fragmentSize % 4;
nextMask = new byte[4];
nextMask[0] = mask[(0 + shift) % 4];
nextMask[0] = mask[(shift) % 4];
nextMask[1] = mask[(1 + shift) % 4];
nextMask[2] = mask[(2 + shift) % 4];
nextMask[3] = mask[(3 + shift) % 4];
@ -316,6 +329,7 @@ public class Parser
boolean isDataFrame = OpCode.isDataFrame(OpCode.getOpCode(firstByte));
// Always autoFragment data frames if payloadLength is greater than maxFrameSize.
// We have already checked payload size in checkFrameSize, so we know we can autoFragment if larger than maxFrameSize.
long maxFrameSize = configuration.getMaxFrameSize();
if (maxFrameSize > 0 && isDataFrame && payloadLength > maxFrameSize)
return autoFragment(buffer, (int)Math.min(available, maxFrameSize));
@ -324,12 +338,12 @@ public class Parser
{
if (available < payloadLength)
{
// not enough to complete this frame
// Can we auto-fragment
// Not enough data to complete this frame, can we auto-fragment?
if (configuration.isAutoFragment() && isDataFrame)
return autoFragment(buffer, available);
// No space in the buffer, so we have to copy the partial payload
// No space in the buffer, so we have to copy the partial payload.
// The size of this allocation is limited by the maxFrameSize.
aggregate = bufferPool.acquire(payloadLength, false);
BufferUtil.append(aggregate, buffer);
return null;
@ -337,15 +351,15 @@ public class Parser
if (available == payloadLength)
{
// All the available data is for this frame and completes it
// All the available data is for this frame and completes it.
ParsedFrame frame = newFrame(firstByte, mask, buffer.slice(), false);
buffer.position(buffer.limit());
state = State.START;
return frame;
}
// The buffer contains all the data for this frame and for subsequent frames
// Copy the just the first part of the buffer as frame payload
// The buffer contains all the data for this frame and for subsequent frames.
// Copy just the first part of the buffer as the frame payload.
int limit = buffer.limit();
int end = buffer.position() + payloadLength;
buffer.limit(end);

View File

@ -20,6 +20,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.io.NullByteBufferPool;
import org.eclipse.jetty.toolchain.test.Hex;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.StringUtil;
@ -33,6 +34,7 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
@ -40,6 +42,7 @@ import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ -227,6 +230,61 @@ public class ParserTest
assertThat("Frame.payloadLength", pActual.getPayloadLength(), is(length));
}
private ByteBuffer toBuffer(long l)
{
ByteBuffer buffer = BufferUtil.allocate(Long.BYTES);
BufferUtil.clearToFill(buffer);
buffer.putLong(l);
BufferUtil.flipToFlush(buffer, 0);
return buffer;
}
@Test
public void testLargeFrame()
{
ByteBuffer expected = ByteBuffer.allocate(65);
expected.put(new byte[]{(byte)0x82});
byte b = 0x7F; // no masking
expected.put(b);
expected.put(toBuffer(Integer.MAX_VALUE));
expected.flip();
Parser parser = new Parser(new NullByteBufferPool());
assertNull(parser.parse(expected));
assertThat(parser.getPayloadLength(), equalTo(Integer.MAX_VALUE));
}
@Test
public void testFrameTooLarge()
{
ByteBuffer expected = ByteBuffer.allocate(65);
expected.put(new byte[]{(byte)0x82});
byte b = 0x7F; // no masking
expected.put(b);
expected.put(toBuffer(Integer.MAX_VALUE + 1L));
expected.flip();
Parser parser = new Parser(new NullByteBufferPool());
assertThrows(MessageTooLargeException.class, () -> parser.parse(expected));
}
@Test
public void testLargestFrame()
{
ByteBuffer expected = ByteBuffer.allocate(65);
expected.put(new byte[]{(byte)0x82});
byte b = 0x7F; // no masking
expected.put(b);
expected.put(new byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF});
expected.flip();
Parser parser = new Parser(new NullByteBufferPool());
assertThrows(MessageTooLargeException.class, () -> parser.parse(expected));
}
/**
* From Autobahn WebSocket Server Testcase 1.2.6
*/