Issue #4571 - review UTF-8 validation in MessageSinks

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-02-20 16:55:02 +11:00
parent b2eddff228
commit adbb3f165e
9 changed files with 324 additions and 42 deletions

View File

@ -221,6 +221,12 @@ public class Frame
return payload;
}
/**
* Get the payload of the frame as a UTF-8 string.
* <p>Should only be used in testing, does not validate the
* UTF-8 and a non fin frame can contain partial UTF-8 characters.</p>
* @return the payload as a UTF-8 string.
*/
public String getPayloadAsUTF8()
{
if (payload == null)

View File

@ -50,7 +50,6 @@ public class ByteArrayMessageSink extends AbstractMessageSink
}
}
@SuppressWarnings("Duplicates")
@Override
public void accept(Frame frame, Callback callback)
{

View File

@ -51,7 +51,6 @@ public class ByteBufferMessageSink extends AbstractMessageSink
}
}
@SuppressWarnings("Duplicates")
@Override
public void accept(Frame frame, Callback callback)
{

View File

@ -43,7 +43,6 @@ public class PartialByteArrayMessageSink extends AbstractMessageSink
}
}
@SuppressWarnings("Duplicates")
@Override
public void accept(Frame frame, Callback callback)
{

View File

@ -19,9 +19,7 @@
package org.eclipse.jetty.websocket.util.messages;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
@ -41,17 +39,13 @@ public class PartialByteBufferMessageSink extends AbstractMessageSink
*/
}
@SuppressWarnings("Duplicates")
@Override
public void accept(Frame frame, Callback callback)
{
try
{
if (frame.hasPayload() || frame.isFin())
{
ByteBuffer buffer = frame.hasPayload() ? frame.getPayload() : BufferUtil.EMPTY_BUFFER;
methodHandle.invoke(buffer, frame.isFin());
}
methodHandle.invoke(frame.getPayload(), frame.isFin());
callback.succeeded();
}

View File

@ -22,26 +22,37 @@ import java.lang.invoke.MethodHandle;
import java.util.Objects;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
public class PartialStringMessageSink extends AbstractMessageSink
{
private Utf8StringBuilder out;
public PartialStringMessageSink(CoreSession session, MethodHandle methodHandle)
{
super(session, methodHandle);
Objects.requireNonNull(methodHandle, "MethodHandle");
}
@SuppressWarnings("Duplicates")
@Override
public void accept(Frame frame, Callback callback)
{
try
{
if (frame.hasPayload() || frame.isFin())
if (out == null)
out = new Utf8StringBuilder(session.getInputBufferSize());
out.append(frame.getPayload());
if (frame.isFin())
{
methodHandle.invoke(frame.getPayloadAsUTF8(), frame.isFin());
methodHandle.invoke(out.toString(), true);
out = null;
}
else
{
methodHandle.invoke(out.takePartialString(), false);
}
callback.succeeded();

View File

@ -18,9 +18,7 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Utf8StringBuilder;
@ -30,7 +28,6 @@ import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException;
public class StringMessageSink extends AbstractMessageSink
{
private static final int BUFFER_SIZE = 1024;
private Utf8StringBuilder out;
private int size;
@ -53,19 +50,10 @@ public class StringMessageSink extends AbstractMessageSink
size, maxTextMessageSize));
}
// If we are fin and out has not been created we don't need to aggregate.
if (frame.isFin() && (out == null))
{
if (frame.hasPayload())
methodHandle.invoke(frame.getPayloadAsUTF8());
else
methodHandle.invoke("");
if (out == null)
out = new Utf8StringBuilder(session.getInputBufferSize());
callback.succeeded();
return;
}
aggregatePayload(frame);
out.append(frame.getPayload());
if (frame.isFin())
methodHandle.invoke(out.toString());
@ -85,18 +73,4 @@ public class StringMessageSink extends AbstractMessageSink
}
}
}
private void aggregatePayload(Frame frame) throws IOException
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new Utf8StringBuilder(BUFFER_SIZE);
// allow for fast fail of BAD utf (incomplete utf will trigger on messageComplete)
out.append(payload);
}
}
}

View File

@ -0,0 +1,153 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.Utf8Appendable;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.util.messages.PartialStringMessageSink;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class PartialStringMessageSinkTest
{
private CoreSession coreSession = new CoreSession.Empty();
private OnMessageEndpoint endpoint = new OnMessageEndpoint();
private PartialStringMessageSink messageSink;
@BeforeEach
public void before() throws Exception
{
messageSink = new PartialStringMessageSink(coreSession, endpoint.getMethodHandle());
}
@Test
public void testValidUtf8() throws Exception
{
ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertThat(message.size(), is(1));
assertThat(message.get(0), is("\uD800\uDF48"));
}
@Test
public void testUtf8Continuation() throws Exception
{
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), callback);
callback.block(5, TimeUnit.SECONDS);
callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertThat(message.size(), is(2));
assertThat(message.get(0), is(""));
assertThat(message.get(1), is("\uD800\uDF48"));
}
@Test
public void testInvalidSingleFrameUtf8() throws Exception
{
ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, invalidUtf8Payload).setFin(true), callback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertTrue(message.isEmpty());
}
@Test
public void testInvalidMultiFrameUtf8() throws Exception
{
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D});
FutureCallback firstCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), firstCallback);
firstCallback.block(5, TimeUnit.SECONDS);
FutureCallback continuationCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), continuationCallback);
// Callback should fail and we only received the first frame which had no full character.
RuntimeException error = assertThrows(RuntimeException.class, () -> continuationCallback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertThat(message.size(), is(1));
assertThat(message.get(0), is(""));
}
public static class OnMessageEndpoint
{
private BlockingArrayQueue<List<String>> messages;
public OnMessageEndpoint()
{
messages = new BlockingArrayQueue<>();
messages.add(new ArrayList<>());
}
public void onMessage(String message, boolean last)
{
messages.get(messages.size() - 1).add(message);
if (last)
messages.add(new ArrayList<>());
}
public MethodHandle getMethodHandle() throws Exception
{
return MethodHandles.lookup()
.findVirtual(this.getClass(), "onMessage", MethodType.methodType(void.class, String.class, boolean.class))
.bindTo(this);
}
}
}

View File

@ -0,0 +1,147 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.Utf8Appendable;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException;
import org.eclipse.jetty.websocket.util.messages.StringMessageSink;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class StringMessageSinkTest
{
private CoreSession coreSession = new CoreSession.Empty();
private OnMessageEndpoint endpoint = new OnMessageEndpoint();
@Test
public void testMaxMessageSize() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
coreSession.setMaxTextMessageSize(3);
messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(MessageTooLargeException.class));
assertNull(endpoint.messages.poll());
}
@Test
public void testValidUtf8() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
assertThat(endpoint.messages.poll(5, TimeUnit.SECONDS), is("\uD800\uDF48"));
}
@Test
public void testUtf8Continuation() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), callback);
callback.block(5, TimeUnit.SECONDS);
callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
assertThat(endpoint.messages.poll(5, TimeUnit.SECONDS), is("\uD800\uDF48"));
}
@Test
public void testInvalidSingleFrameUtf8() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, invalidUtf8Payload).setFin(true), callback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
assertNull(endpoint.messages.poll());
}
@Test
public void testInvalidMultiFrameUtf8() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D});
FutureCallback firstCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), firstCallback);
firstCallback.block(5, TimeUnit.SECONDS);
FutureCallback continuationCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), continuationCallback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> continuationCallback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
assertNull(endpoint.messages.poll());
}
public static class OnMessageEndpoint
{
private BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();
public void onMessage(String message)
{
messages.add(message);
}
public MethodHandle getMethodHandle() throws Exception
{
return MethodHandles.lookup()
.findVirtual(this.getClass(), "onMessage", MethodType.methodType(void.class, String.class))
.bindTo(this);
}
}
}