Merge pull request #4593 from eclipse/jetty-10.0.x-4571-MessageSink

Issue #4571 - websocket aggregating text and binary MessageSinks
This commit is contained in:
Lachlan 2020-03-11 14:14:15 +11:00 committed by GitHub
commit b0ddba49da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 403 additions and 164 deletions

View File

@ -221,6 +221,12 @@ public class Frame
return payload; return payload;
} }
/**
* Get the payload of the frame as a UTF-8 string.
* <p>Should only be used in testing, does not validate the
* UTF-8 and a non fin frame can contain partial UTF-8 characters.</p>
* @return the payload as a UTF-8 string.
*/
public String getPayloadAsUTF8() public String getPayloadAsUTF8()
{ {
if (payload == null) if (payload == null)

View File

@ -477,7 +477,7 @@ public class MessageReceivingTest
@Override @Override
public void onMessage(ByteBuffer message) public void onMessage(ByteBuffer message)
{ {
final String stringResult = new String(message.array()); final String stringResult = BufferUtil.toString(message);
messageQueue.offer(stringResult); messageQueue.offer(stringResult);
} }
} }

View File

@ -19,6 +19,7 @@
package org.eclipse.jetty.websocket.util.messages; package org.eclipse.jetty.websocket.util.messages;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -49,40 +50,40 @@ public class ByteArrayMessageSink extends AbstractMessageSink
} }
} }
@SuppressWarnings("Duplicates")
@Override @Override
public void accept(Frame frame, Callback callback) public void accept(Frame frame, Callback callback)
{ {
try try
{ {
if (frame.hasPayload()) size += frame.getPayloadLength();
{
ByteBuffer payload = frame.getPayload();
size += payload.remaining();
long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); long maxBinaryMessageSize = session.getMaxBinaryMessageSize();
if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize)
{ {
throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary buffer size) %,d", throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d",
size, maxBinaryMessageSize)); size, maxBinaryMessageSize));
} }
if (out == null) // If we are fin and no OutputStream has been created we don't need to aggregate.
out = new ByteArrayOutputStream(BUFFER_SIZE); if (frame.isFin() && (out == null))
BufferUtil.writeTo(payload, out);
}
if (frame.isFin())
{ {
if (out != null) if (frame.hasPayload())
{ {
byte[] buf = out.toByteArray(); byte[] buf = BufferUtil.toArray(frame.getPayload());
methodHandle.invoke(buf, 0, buf.length); methodHandle.invoke(buf, 0, buf.length);
} }
else else
methodHandle.invoke(EMPTY_BUFFER, 0, 0); methodHandle.invoke(EMPTY_BUFFER, 0, 0);
callback.succeeded();
return;
} }
aggregatePayload(frame);
if (frame.isFin())
{
byte[] buf = out.toByteArray();
methodHandle.invoke(buf, 0, buf.length);
}
callback.succeeded(); callback.succeeded();
} }
catch (Throwable t) catch (Throwable t)
@ -99,4 +100,15 @@ public class ByteArrayMessageSink extends AbstractMessageSink
} }
} }
} }
private void aggregatePayload(Frame frame) throws IOException
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new ByteArrayOutputStream(BUFFER_SIZE);
BufferUtil.writeTo(payload, out);
}
}
} }

View File

@ -19,6 +19,7 @@
package org.eclipse.jetty.websocket.util.messages; package org.eclipse.jetty.websocket.util.messages;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -50,16 +51,12 @@ public class ByteBufferMessageSink extends AbstractMessageSink
} }
} }
@SuppressWarnings("Duplicates")
@Override @Override
public void accept(Frame frame, Callback callback) public void accept(Frame frame, Callback callback)
{ {
try try
{ {
if (frame.hasPayload()) size += frame.getPayloadLength();
{
ByteBuffer payload = frame.getPayload();
size += payload.remaining();
long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); long maxBinaryMessageSize = session.getMaxBinaryMessageSize();
if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize)
{ {
@ -67,21 +64,22 @@ public class ByteBufferMessageSink extends AbstractMessageSink
size, maxBinaryMessageSize)); size, maxBinaryMessageSize));
} }
if (out == null) // If we are fin and no OutputStream has been created we don't need to aggregate.
out = new ByteArrayOutputStream(BUFFER_SIZE); if (frame.isFin() && (out == null))
BufferUtil.writeTo(payload, out);
payload.position(payload.limit()); // consume buffer
}
if (frame.isFin())
{ {
if (out != null) if (frame.hasPayload())
methodHandle.invoke(ByteBuffer.wrap(out.toByteArray())); methodHandle.invoke(frame.getPayload());
else else
methodHandle.invoke(BufferUtil.EMPTY_BUFFER); methodHandle.invoke(BufferUtil.EMPTY_BUFFER);
callback.succeeded();
return;
} }
aggregatePayload(frame);
if (frame.isFin())
methodHandle.invoke(ByteBuffer.wrap(out.toByteArray()));
callback.succeeded(); callback.succeeded();
} }
catch (Throwable t) catch (Throwable t)
@ -98,4 +96,18 @@ public class ByteBufferMessageSink extends AbstractMessageSink
} }
} }
} }
private void aggregatePayload(Frame frame) throws IOException
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new ByteArrayOutputStream(BUFFER_SIZE);
BufferUtil.writeTo(payload, out);
payload.position(payload.limit()); // consume buffer
}
}
} }

View File

@ -136,20 +136,7 @@ public abstract class DispatchedMessageSink<T> extends AbstractMessageSink
if (frame.isFin()) if (frame.isFin())
{ {
CompletableFuture<Void> finComplete = new CompletableFuture<>(); CompletableFuture<Void> finComplete = new CompletableFuture<>();
frameCallback = new Callback() frameCallback = Callback.from(() -> finComplete.complete(null), finComplete::completeExceptionally);
{
@Override
public void failed(Throwable cause)
{
finComplete.completeExceptionally(cause);
}
@Override
public void succeeded()
{
finComplete.complete(null);
}
};
CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete( CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete(
(aVoid, throwable) -> (aVoid, throwable) ->
{ {

View File

@ -20,7 +20,6 @@ package org.eclipse.jetty.websocket.util.messages;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
@ -30,6 +29,8 @@ import org.eclipse.jetty.websocket.util.InvalidSignatureException;
public class PartialByteArrayMessageSink extends AbstractMessageSink public class PartialByteArrayMessageSink extends AbstractMessageSink
{ {
private static byte[] EMPTY_BUFFER = new byte[0];
public PartialByteArrayMessageSink(CoreSession session, MethodHandle methodHandle) public PartialByteArrayMessageSink(CoreSession session, MethodHandle methodHandle)
{ {
super(session, methodHandle); super(session, methodHandle);
@ -42,28 +43,17 @@ public class PartialByteArrayMessageSink extends AbstractMessageSink
} }
} }
@SuppressWarnings("Duplicates")
@Override @Override
public void accept(Frame frame, Callback callback) public void accept(Frame frame, Callback callback)
{ {
try try
{ {
byte[] buffer; if (frame.hasPayload() || frame.isFin())
int offset = 0;
int length = 0;
if (frame.hasPayload())
{ {
ByteBuffer payload = frame.getPayload(); byte[] buffer = frame.hasPayload() ? BufferUtil.toArray(frame.getPayload()) : EMPTY_BUFFER;
length = payload.remaining(); methodHandle.invoke(buffer, 0, buffer.length, frame.isFin());
buffer = BufferUtil.toArray(payload);
}
else
{
buffer = new byte[0];
} }
methodHandle.invoke(buffer, offset, length, frame.isFin());
callback.succeeded(); callback.succeeded();
} }
catch (Throwable t) catch (Throwable t)

View File

@ -19,9 +19,7 @@
package org.eclipse.jetty.websocket.util.messages; package org.eclipse.jetty.websocket.util.messages;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.Frame;
@ -41,29 +39,13 @@ public class PartialByteBufferMessageSink extends AbstractMessageSink
*/ */
} }
@SuppressWarnings("Duplicates")
@Override @Override
public void accept(Frame frame, Callback callback) public void accept(Frame frame, Callback callback)
{ {
try try
{ {
ByteBuffer buffer; if (frame.hasPayload() || frame.isFin())
methodHandle.invoke(frame.getPayload(), frame.isFin());
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
// copy buffer here
buffer = ByteBuffer.allocate(payload.remaining());
BufferUtil.clearToFill(buffer);
BufferUtil.put(payload, buffer);
BufferUtil.flipToFlush(buffer, 0);
}
else
{
buffer = BufferUtil.EMPTY_BUFFER;
}
methodHandle.invoke(buffer, frame.isFin());
callback.succeeded(); callback.succeeded();
} }

View File

@ -19,73 +19,40 @@
package org.eclipse.jetty.websocket.util.messages; package org.eclipse.jetty.websocket.util.messages;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import java.util.Objects; import java.util.Objects;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Utf8StringBuilder; import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.Frame;
public class PartialStringMessageSink extends AbstractMessageSink public class PartialStringMessageSink extends AbstractMessageSink
{ {
private static final Logger LOG = Log.getLogger(PartialStringMessageSink.class); private Utf8StringBuilder out;
private Utf8StringBuilder utf;
private int size;
public PartialStringMessageSink(CoreSession session, MethodHandle methodHandle) public PartialStringMessageSink(CoreSession session, MethodHandle methodHandle)
{ {
super(session, methodHandle); super(session, methodHandle);
Objects.requireNonNull(methodHandle, "MethodHandle"); Objects.requireNonNull(methodHandle, "MethodHandle");
this.size = 0;
} }
@SuppressWarnings("Duplicates")
@Override @Override
public void accept(Frame frame, Callback callback) public void accept(Frame frame, Callback callback)
{ {
try try
{ {
if (utf == null) if (out == null)
utf = new Utf8StringBuilder(1024); out = new Utf8StringBuilder(session.getInputBufferSize());
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
//TODO we should fragment on maxTextMessageBufferSize not limit
//TODO also for PartialBinaryMessageSink
/*
if ((session.getMaxTextMessageBufferSize() > 0) && (size + payload.remaining() > session.getMaxTextMessageBufferSize()))
{
throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max text buffer size) %,d",
size + payload.remaining(), session.getMaxTextMessageBufferSize()));
}
*/
size += payload.remaining();
if (LOG.isDebugEnabled())
LOG.debug("Raw Payload {}", BufferUtil.toDetailString(payload));
// allow for fast fail of BAD utf
utf.append(payload);
}
out.append(frame.getPayload());
if (frame.isFin()) if (frame.isFin())
{ {
// Using toString to trigger failure on incomplete UTF-8 methodHandle.invoke(out.toString(), true);
methodHandle.invoke(utf.toString(), true); out = null;
// reset
size = 0;
utf = null;
} }
else else
{ {
methodHandle.invoke(utf.takePartialString(), false); methodHandle.invoke(out.takePartialString(), false);
} }
callback.succeeded(); callback.succeeded();

View File

@ -19,21 +19,16 @@
package org.eclipse.jetty.websocket.util.messages; package org.eclipse.jetty.websocket.util.messages;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Utf8StringBuilder; import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException; import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException;
public class StringMessageSink extends AbstractMessageSink public class StringMessageSink extends AbstractMessageSink
{ {
private static final Logger LOG = Log.getLogger(StringMessageSink.class); private Utf8StringBuilder out;
private Utf8StringBuilder utf;
private int size; private int size;
public StringMessageSink(CoreSession session, MethodHandle methodHandle) public StringMessageSink(CoreSession session, MethodHandle methodHandle)
@ -42,17 +37,12 @@ public class StringMessageSink extends AbstractMessageSink
this.size = 0; this.size = 0;
} }
@SuppressWarnings("Duplicates")
@Override @Override
public void accept(Frame frame, Callback callback) public void accept(Frame frame, Callback callback)
{ {
try try
{ {
if (frame.hasPayload()) size += frame.getPayloadLength();
{
ByteBuffer payload = frame.getPayload();
size += payload.remaining();
long maxTextMessageSize = session.getMaxTextMessageSize(); long maxTextMessageSize = session.getMaxTextMessageSize();
if (maxTextMessageSize > 0 && size > maxTextMessageSize) if (maxTextMessageSize > 0 && size > maxTextMessageSize)
{ {
@ -60,28 +50,12 @@ public class StringMessageSink extends AbstractMessageSink
size, maxTextMessageSize)); size, maxTextMessageSize));
} }
if (utf == null) if (out == null)
utf = new Utf8StringBuilder(1024); out = new Utf8StringBuilder(session.getInputBufferSize());
if (LOG.isDebugEnabled())
LOG.debug("Raw Payload {}", BufferUtil.toDetailString(payload));
// allow for fast fail of BAD utf (incomplete utf will trigger on messageComplete)
utf.append(payload);
}
out.append(frame.getPayload());
if (frame.isFin()) if (frame.isFin())
{ methodHandle.invoke(out.toString());
// notify event
if (utf != null)
methodHandle.invoke(utf.toString());
else
methodHandle.invoke("");
// reset
size = 0;
utf = null;
}
callback.succeeded(); callback.succeeded();
} }
@ -89,5 +63,14 @@ public class StringMessageSink extends AbstractMessageSink
{ {
callback.failed(t); callback.failed(t);
} }
finally
{
if (frame.isFin())
{
// reset
size = 0;
out = null;
}
}
} }
} }

View File

@ -0,0 +1,153 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.Utf8Appendable;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.util.messages.PartialStringMessageSink;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class PartialStringMessageSinkTest
{
private CoreSession coreSession = new CoreSession.Empty();
private OnMessageEndpoint endpoint = new OnMessageEndpoint();
private PartialStringMessageSink messageSink;
@BeforeEach
public void before() throws Exception
{
messageSink = new PartialStringMessageSink(coreSession, endpoint.getMethodHandle());
}
@Test
public void testValidUtf8() throws Exception
{
ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertThat(message.size(), is(1));
assertThat(message.get(0), is("\uD800\uDF48"));
}
@Test
public void testUtf8Continuation() throws Exception
{
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), callback);
callback.block(5, TimeUnit.SECONDS);
callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertThat(message.size(), is(2));
assertThat(message.get(0), is(""));
assertThat(message.get(1), is("\uD800\uDF48"));
}
@Test
public void testInvalidSingleFrameUtf8() throws Exception
{
ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, invalidUtf8Payload).setFin(true), callback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertTrue(message.isEmpty());
}
@Test
public void testInvalidMultiFrameUtf8() throws Exception
{
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D});
FutureCallback firstCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), firstCallback);
firstCallback.block(5, TimeUnit.SECONDS);
FutureCallback continuationCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), continuationCallback);
// Callback should fail and we only received the first frame which had no full character.
RuntimeException error = assertThrows(RuntimeException.class, () -> continuationCallback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
List<String> message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS));
assertThat(message.size(), is(1));
assertThat(message.get(0), is(""));
}
public static class OnMessageEndpoint
{
private BlockingArrayQueue<List<String>> messages;
public OnMessageEndpoint()
{
messages = new BlockingArrayQueue<>();
messages.add(new ArrayList<>());
}
public void onMessage(String message, boolean last)
{
messages.get(messages.size() - 1).add(message);
if (last)
messages.add(new ArrayList<>());
}
public MethodHandle getMethodHandle() throws Exception
{
return MethodHandles.lookup()
.findVirtual(this.getClass(), "onMessage", MethodType.methodType(void.class, String.class, boolean.class))
.bindTo(this);
}
}
}

View File

@ -0,0 +1,147 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.Utf8Appendable;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException;
import org.eclipse.jetty.websocket.util.messages.StringMessageSink;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class StringMessageSinkTest
{
private CoreSession coreSession = new CoreSession.Empty();
private OnMessageEndpoint endpoint = new OnMessageEndpoint();
@Test
public void testMaxMessageSize() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
coreSession.setMaxTextMessageSize(3);
messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(MessageTooLargeException.class));
assertNull(endpoint.messages.poll());
}
@Test
public void testValidUtf8() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
assertThat(endpoint.messages.poll(5, TimeUnit.SECONDS), is("\uD800\uDF48"));
}
@Test
public void testUtf8Continuation() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D, (byte)0x88});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), callback);
callback.block(5, TimeUnit.SECONDS);
callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), callback);
callback.block(5, TimeUnit.SECONDS);
assertThat(endpoint.messages.poll(5, TimeUnit.SECONDS), is("\uD800\uDF48"));
}
@Test
public void testInvalidSingleFrameUtf8() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D});
FutureCallback callback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, invalidUtf8Payload).setFin(true), callback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
assertNull(endpoint.messages.poll());
}
@Test
public void testInvalidMultiFrameUtf8() throws Exception
{
StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle());
ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90});
ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D});
FutureCallback firstCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), firstCallback);
firstCallback.block(5, TimeUnit.SECONDS);
FutureCallback continuationCallback = new FutureCallback();
messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), continuationCallback);
// Callback should fail and we don't receive the message in the sink.
RuntimeException error = assertThrows(RuntimeException.class, () -> continuationCallback.block(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class));
assertNull(endpoint.messages.poll());
}
public static class OnMessageEndpoint
{
private BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();
public void onMessage(String message)
{
messages.add(message);
}
public MethodHandle getMethodHandle() throws Exception
{
return MethodHandles.lookup()
.findVirtual(this.getClass(), "onMessage", MethodType.methodType(void.class, String.class))
.bindTo(this);
}
}
}