From 36cccd2c885540cabb2f300a86fb0e453ed33037 Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Fri, 14 Feb 2020 12:26:59 +1100 Subject: [PATCH 1/4] Issue #4571 - optimise aggregating text and binary MessageSinks Signed-off-by: Lachlan Roberts --- .../jetty/io/ByteBufferPoolOutputStream.java | 111 ++++++++++++++++++ .../util/messages/ByteArrayMessageSink.java | 55 +++++---- .../util/messages/ByteBufferMessageSink.java | 54 +++++---- .../util/messages/StringMessageSink.java | 79 +++++++------ 4 files changed, 223 insertions(+), 76 deletions(-) create mode 100644 jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java b/jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java new file mode 100644 index 00000000000..3dfa67791c6 --- /dev/null +++ b/jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java @@ -0,0 +1,111 @@ +// +// ======================================================================== +// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under +// the terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0 +// +// This Source Code may also be made available under the following +// Secondary Licenses when the conditions for such availability set +// forth in the Eclipse Public License, v. 2.0 are satisfied: +// the Apache License v2.0 which is available at +// https://www.apache.org/licenses/LICENSE-2.0 +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.io; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Objects; + +import org.eclipse.jetty.util.BufferUtil; + +public class ByteBufferPoolOutputStream extends OutputStream +{ + private final ByteBufferPool bufferPool; + private final ArrayList buffers; + private final boolean direct; + private final int acquireSize; + + private ByteBuffer aggregateBuffer; + private int size = 0; + + public ByteBufferPoolOutputStream(ByteBufferPool bufferPool, int acquireSize, boolean direct) + { + this.buffers = new ArrayList<>(); + this.direct = direct; + this.bufferPool = Objects.requireNonNull(bufferPool); + this.acquireSize = acquireSize; + if (acquireSize <= 0) + throw new IllegalArgumentException(); + + this.buffers.add(bufferPool.acquire(acquireSize, direct)); + } + + @Override + public void write(int b) throws IOException + { + write(new byte[]{(byte)b}, 0, 1); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + write(ByteBuffer.wrap(b, off, len)); + } + + public void write(ByteBuffer data) + { + while (data.hasRemaining()) + { + ByteBuffer buffer = buffers.get(buffers.size() - 1); + size += BufferUtil.append(buffer, data); + if (!buffer.hasRemaining()) + buffers.add(bufferPool.acquire(acquireSize, direct)); + } + } + + public int size() + { + return size; + } + + public ByteBuffer toByteBuffer() + { + releaseAggregate(); + aggregateBuffer = bufferPool.acquire(size, direct); + for (ByteBuffer data : buffers) + { + BufferUtil.append(aggregateBuffer, data); + } + return aggregateBuffer; + } + + public byte[] toByteArray() + { + return BufferUtil.toArray(toByteBuffer()); + } + + private void releaseAggregate() + { + if (aggregateBuffer != null) + { + bufferPool.release(aggregateBuffer); + aggregateBuffer = null; + } + } + + @Override + public void close() + { + releaseAggregate(); + for (ByteBuffer buffer : buffers) + bufferPool.release(buffer); + } +} diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java index 831b33ffc15..4763ee16946 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java @@ -19,6 +19,7 @@ package org.eclipse.jetty.websocket.util.messages; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; import java.nio.ByteBuffer; @@ -55,34 +56,28 @@ public class ByteArrayMessageSink extends AbstractMessageSink { try { - if (frame.hasPayload()) + // If we are fin and no OutputStream has been created we don't need to aggregate. + if (frame.isFin() && (out == null)) { - ByteBuffer payload = frame.getPayload(); - size += payload.remaining(); - long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); - if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) + if (frame.hasPayload()) { - throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary buffer size) %,d", - size, maxBinaryMessageSize)); - } - - if (out == null) - out = new ByteArrayOutputStream(BUFFER_SIZE); - - BufferUtil.writeTo(payload, out); - } - - if (frame.isFin()) - { - if (out != null) - { - byte[] buf = out.toByteArray(); + byte[] buf = BufferUtil.toArray(frame.getPayload()); methodHandle.invoke(buf, 0, buf.length); } else methodHandle.invoke(EMPTY_BUFFER, 0, 0); + + callback.succeeded(); + return; } + + aggregatePayload(frame); + if (frame.isFin()) + { + byte[] buf = out.toByteArray(); + methodHandle.invoke(buf, 0, buf.length); + } callback.succeeded(); } catch (Throwable t) @@ -99,4 +94,24 @@ public class ByteArrayMessageSink extends AbstractMessageSink } } } + + private void aggregatePayload(Frame frame) throws IOException + { + if (frame.hasPayload()) + { + ByteBuffer payload = frame.getPayload(); + size += payload.remaining(); + long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); + if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) + { + throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", + size, maxBinaryMessageSize)); + } + + if (out == null) + out = new ByteArrayOutputStream(BUFFER_SIZE); + + BufferUtil.writeTo(payload, out); + } + } } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java index 5f86a66e607..1b4d6094701 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java @@ -19,6 +19,7 @@ package org.eclipse.jetty.websocket.util.messages; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; import java.nio.ByteBuffer; @@ -56,32 +57,22 @@ public class ByteBufferMessageSink extends AbstractMessageSink { try { - if (frame.hasPayload()) + // If we are fin and no OutputStream has been created we don't need to aggregate. + if (frame.isFin() && (out == null)) { - ByteBuffer payload = frame.getPayload(); - size += payload.remaining(); - long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); - if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) - { - throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", - size, maxBinaryMessageSize)); - } - - if (out == null) - out = new ByteArrayOutputStream(BUFFER_SIZE); - - BufferUtil.writeTo(payload, out); - payload.position(payload.limit()); // consume buffer - } - - if (frame.isFin()) - { - if (out != null) - methodHandle.invoke(ByteBuffer.wrap(out.toByteArray())); + if (frame.hasPayload()) + methodHandle.invoke(frame.getPayload()); else methodHandle.invoke(BufferUtil.EMPTY_BUFFER); + + callback.succeeded(); + return; } + aggregatePayload(frame); + if (frame.isFin()) + methodHandle.invoke(ByteBuffer.wrap(out.toByteArray())); + callback.succeeded(); } catch (Throwable t) @@ -98,4 +89,25 @@ public class ByteBufferMessageSink extends AbstractMessageSink } } } + + private void aggregatePayload(Frame frame) throws IOException + { + if (frame.hasPayload()) + { + ByteBuffer payload = frame.getPayload(); + size += payload.remaining(); + long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); + if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) + { + throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", + size, maxBinaryMessageSize)); + } + + if (out == null) + out = new ByteArrayOutputStream(BUFFER_SIZE); + + BufferUtil.writeTo(payload, out); + payload.position(payload.limit()); // consume buffer + } + } } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java index 5be932fc595..9f49405d17f 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java @@ -18,22 +18,20 @@ package org.eclipse.jetty.websocket.util.messages; +import java.io.IOException; import java.lang.invoke.MethodHandle; import java.nio.ByteBuffer; -import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Utf8StringBuilder; -import org.eclipse.jetty.util.log.Log; -import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException; public class StringMessageSink extends AbstractMessageSink { - private static final Logger LOG = Log.getLogger(StringMessageSink.class); - private Utf8StringBuilder utf; + private static final int BUFFER_SIZE = 1024; + private Utf8StringBuilder out; private int size; public StringMessageSink(CoreSession session, MethodHandle methodHandle) @@ -42,52 +40,63 @@ public class StringMessageSink extends AbstractMessageSink this.size = 0; } - @SuppressWarnings("Duplicates") @Override public void accept(Frame frame, Callback callback) { try { - if (frame.hasPayload()) + // If we are fin and out has not been created we don't need to aggregate. + if (frame.isFin() && (out == null)) { - ByteBuffer payload = frame.getPayload(); - - size += payload.remaining(); - long maxTextMessageSize = session.getMaxTextMessageSize(); - if (maxTextMessageSize > 0 && size > maxTextMessageSize) - { - throw new MessageTooLargeException(String.format("Text message too large: (actual) %,d > (configured max text message size) %,d", - size, maxTextMessageSize)); - } - - if (utf == null) - utf = new Utf8StringBuilder(1024); - - if (LOG.isDebugEnabled()) - LOG.debug("Raw Payload {}", BufferUtil.toDetailString(payload)); - - // allow for fast fail of BAD utf (incomplete utf will trigger on messageComplete) - utf.append(payload); - } - - if (frame.isFin()) - { - // notify event - if (utf != null) - methodHandle.invoke(utf.toString()); + if (frame.hasPayload()) + methodHandle.invoke(frame.getPayloadAsUTF8()); else methodHandle.invoke(""); - // reset - size = 0; - utf = null; + callback.succeeded(); + return; } + aggregatePayload(frame); + if (frame.isFin()) + methodHandle.invoke(out.toString()); + callback.succeeded(); } catch (Throwable t) { callback.failed(t); } + finally + { + if (frame.isFin()) + { + // reset + size = 0; + out.reset(); + out = null; + } + } + } + + private void aggregatePayload(Frame frame) throws IOException + { + if (frame.hasPayload()) + { + ByteBuffer payload = frame.getPayload(); + size += frame.getPayloadLength(); + long maxTextMessageSize = session.getMaxTextMessageSize(); + if (maxTextMessageSize > 0 && size > maxTextMessageSize) + { + throw new MessageTooLargeException(String.format("Text message too large: (actual) %,d > (configured max text message size) %,d", + size, maxTextMessageSize)); + } + + if (out == null) + out = new Utf8StringBuilder(BUFFER_SIZE); + + // allow for fast fail of BAD utf (incomplete utf will trigger on messageComplete) + out.append(payload); + } } } From 71b11f088768358f358dfe8d32ec4e940dc0f27c Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Fri, 14 Feb 2020 12:32:20 +1100 Subject: [PATCH 2/4] Issue #4571 - simplify partial MessageSinks reduce copying Signed-off-by: Lachlan Roberts --- .../util/messages/DispatchedMessageSink.java | 15 +----- .../messages/PartialByteArrayMessageSink.java | 19 ++------ .../PartialByteBufferMessageSink.java | 18 ++----- .../messages/PartialStringMessageSink.java | 48 +------------------ .../util/messages/StringMessageSink.java | 1 - 5 files changed, 11 insertions(+), 90 deletions(-) diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java index f4eab98533d..fdc8375354e 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java @@ -136,20 +136,7 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink if (frame.isFin()) { CompletableFuture finComplete = new CompletableFuture<>(); - frameCallback = new Callback() - { - @Override - public void failed(Throwable cause) - { - finComplete.completeExceptionally(cause); - } - - @Override - public void succeeded() - { - finComplete.complete(null); - } - }; + frameCallback = Callback.from(() -> finComplete.complete(null), finComplete::completeExceptionally); CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete( (aVoid, throwable) -> { diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java index 62e6bfe8f2a..b15c2c8c8b1 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java @@ -20,7 +20,6 @@ package org.eclipse.jetty.websocket.util.messages; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; -import java.nio.ByteBuffer; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; @@ -30,6 +29,8 @@ import org.eclipse.jetty.websocket.util.InvalidSignatureException; public class PartialByteArrayMessageSink extends AbstractMessageSink { + private static byte[] EMPTY_BUFFER = new byte[0]; + public PartialByteArrayMessageSink(CoreSession session, MethodHandle methodHandle) { super(session, methodHandle); @@ -48,22 +49,12 @@ public class PartialByteArrayMessageSink extends AbstractMessageSink { try { - byte[] buffer; - int offset = 0; - int length = 0; - - if (frame.hasPayload()) + if (frame.hasPayload() || frame.isFin()) { - ByteBuffer payload = frame.getPayload(); - length = payload.remaining(); - buffer = BufferUtil.toArray(payload); - } - else - { - buffer = new byte[0]; + byte[] buffer = frame.hasPayload() ? BufferUtil.toArray(frame.getPayload()) : EMPTY_BUFFER; + methodHandle.invoke(buffer, 0, buffer.length, frame.isFin()); } - methodHandle.invoke(buffer, offset, length, frame.isFin()); callback.succeeded(); } catch (Throwable t) diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java index 70cf34b8376..55ed3bbcf05 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java @@ -47,23 +47,11 @@ public class PartialByteBufferMessageSink extends AbstractMessageSink { try { - ByteBuffer buffer; - - if (frame.hasPayload()) + if (frame.hasPayload() || frame.isFin()) { - ByteBuffer payload = frame.getPayload(); - // copy buffer here - buffer = ByteBuffer.allocate(payload.remaining()); - BufferUtil.clearToFill(buffer); - BufferUtil.put(payload, buffer); - BufferUtil.flipToFlush(buffer, 0); + ByteBuffer buffer = frame.hasPayload() ? frame.getPayload() : BufferUtil.EMPTY_BUFFER; + methodHandle.invoke(buffer, frame.isFin()); } - else - { - buffer = BufferUtil.EMPTY_BUFFER; - } - - methodHandle.invoke(buffer, frame.isFin()); callback.succeeded(); } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java index c6451482f15..0a8e93883be 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java @@ -19,28 +19,18 @@ package org.eclipse.jetty.websocket.util.messages; import java.lang.invoke.MethodHandle; -import java.nio.ByteBuffer; import java.util.Objects; -import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; -import org.eclipse.jetty.util.Utf8StringBuilder; -import org.eclipse.jetty.util.log.Log; -import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; public class PartialStringMessageSink extends AbstractMessageSink { - private static final Logger LOG = Log.getLogger(PartialStringMessageSink.class); - private Utf8StringBuilder utf; - private int size; - public PartialStringMessageSink(CoreSession session, MethodHandle methodHandle) { super(session, methodHandle); Objects.requireNonNull(methodHandle, "MethodHandle"); - this.size = 0; } @SuppressWarnings("Duplicates") @@ -49,43 +39,9 @@ public class PartialStringMessageSink extends AbstractMessageSink { try { - if (utf == null) - utf = new Utf8StringBuilder(1024); - - if (frame.hasPayload()) + if (frame.hasPayload() || frame.isFin()) { - ByteBuffer payload = frame.getPayload(); - - //TODO we should fragment on maxTextMessageBufferSize not limit - //TODO also for PartialBinaryMessageSink - /* - if ((session.getMaxTextMessageBufferSize() > 0) && (size + payload.remaining() > session.getMaxTextMessageBufferSize())) - { - throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max text buffer size) %,d", - size + payload.remaining(), session.getMaxTextMessageBufferSize())); - } - */ - - size += payload.remaining(); - - if (LOG.isDebugEnabled()) - LOG.debug("Raw Payload {}", BufferUtil.toDetailString(payload)); - - // allow for fast fail of BAD utf - utf.append(payload); - } - - if (frame.isFin()) - { - // Using toString to trigger failure on incomplete UTF-8 - methodHandle.invoke(utf.toString(), true); - // reset - size = 0; - utf = null; - } - else - { - methodHandle.invoke(utf.takePartialString(), false); + methodHandle.invoke(frame.getPayloadAsUTF8(), frame.isFin()); } callback.succeeded(); diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java index 9f49405d17f..bd029bc7914 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java @@ -73,7 +73,6 @@ public class StringMessageSink extends AbstractMessageSink { // reset size = 0; - out.reset(); out = null; } } From b2eddff2282f523fbd18c82aea46ad1bb7b5816f Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Wed, 19 Feb 2020 14:44:33 +1100 Subject: [PATCH 3/4] Issue #4571 - fix broken tests Signed-off-by: Lachlan Roberts --- .../jetty/io/ByteBufferPoolOutputStream.java | 111 ------------------ .../tests/client/MessageReceivingTest.java | 2 +- .../util/messages/ByteArrayMessageSink.java | 18 ++- .../util/messages/ByteBufferMessageSink.java | 15 +-- .../util/messages/StringMessageSink.java | 15 +-- 5 files changed, 25 insertions(+), 136 deletions(-) delete mode 100644 jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java b/jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java deleted file mode 100644 index 3dfa67791c6..00000000000 --- a/jetty-io/src/main/java/org/eclipse/jetty/io/ByteBufferPoolOutputStream.java +++ /dev/null @@ -1,111 +0,0 @@ -// -// ======================================================================== -// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. -// -// This program and the accompanying materials are made available under -// the terms of the Eclipse Public License 2.0 which is available at -// https://www.eclipse.org/legal/epl-2.0 -// -// This Source Code may also be made available under the following -// Secondary Licenses when the conditions for such availability set -// forth in the Eclipse Public License, v. 2.0 are satisfied: -// the Apache License v2.0 which is available at -// https://www.apache.org/licenses/LICENSE-2.0 -// -// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 -// ======================================================================== -// - -package org.eclipse.jetty.io; - -import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Objects; - -import org.eclipse.jetty.util.BufferUtil; - -public class ByteBufferPoolOutputStream extends OutputStream -{ - private final ByteBufferPool bufferPool; - private final ArrayList buffers; - private final boolean direct; - private final int acquireSize; - - private ByteBuffer aggregateBuffer; - private int size = 0; - - public ByteBufferPoolOutputStream(ByteBufferPool bufferPool, int acquireSize, boolean direct) - { - this.buffers = new ArrayList<>(); - this.direct = direct; - this.bufferPool = Objects.requireNonNull(bufferPool); - this.acquireSize = acquireSize; - if (acquireSize <= 0) - throw new IllegalArgumentException(); - - this.buffers.add(bufferPool.acquire(acquireSize, direct)); - } - - @Override - public void write(int b) throws IOException - { - write(new byte[]{(byte)b}, 0, 1); - } - - @Override - public void write(byte[] b, int off, int len) throws IOException - { - write(ByteBuffer.wrap(b, off, len)); - } - - public void write(ByteBuffer data) - { - while (data.hasRemaining()) - { - ByteBuffer buffer = buffers.get(buffers.size() - 1); - size += BufferUtil.append(buffer, data); - if (!buffer.hasRemaining()) - buffers.add(bufferPool.acquire(acquireSize, direct)); - } - } - - public int size() - { - return size; - } - - public ByteBuffer toByteBuffer() - { - releaseAggregate(); - aggregateBuffer = bufferPool.acquire(size, direct); - for (ByteBuffer data : buffers) - { - BufferUtil.append(aggregateBuffer, data); - } - return aggregateBuffer; - } - - public byte[] toByteArray() - { - return BufferUtil.toArray(toByteBuffer()); - } - - private void releaseAggregate() - { - if (aggregateBuffer != null) - { - bufferPool.release(aggregateBuffer); - aggregateBuffer = null; - } - } - - @Override - public void close() - { - releaseAggregate(); - for (ByteBuffer buffer : buffers) - bufferPool.release(buffer); - } -} diff --git a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/MessageReceivingTest.java b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/MessageReceivingTest.java index c901acc48d2..6da56e34586 100644 --- a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/MessageReceivingTest.java +++ b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/MessageReceivingTest.java @@ -477,7 +477,7 @@ public class MessageReceivingTest @Override public void onMessage(ByteBuffer message) { - final String stringResult = new String(message.array()); + final String stringResult = BufferUtil.toString(message); messageQueue.offer(stringResult); } } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java index 4763ee16946..6e02ef13c4b 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java @@ -56,6 +56,14 @@ public class ByteArrayMessageSink extends AbstractMessageSink { try { + size += frame.getPayloadLength(); + long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); + if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) + { + throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", + size, maxBinaryMessageSize)); + } + // If we are fin and no OutputStream has been created we don't need to aggregate. if (frame.isFin() && (out == null)) { @@ -71,7 +79,6 @@ public class ByteArrayMessageSink extends AbstractMessageSink return; } - aggregatePayload(frame); if (frame.isFin()) { @@ -100,17 +107,8 @@ public class ByteArrayMessageSink extends AbstractMessageSink if (frame.hasPayload()) { ByteBuffer payload = frame.getPayload(); - size += payload.remaining(); - long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); - if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) - { - throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", - size, maxBinaryMessageSize)); - } - if (out == null) out = new ByteArrayOutputStream(BUFFER_SIZE); - BufferUtil.writeTo(payload, out); } } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java index 1b4d6094701..85f9f8e2af3 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java @@ -57,6 +57,14 @@ public class ByteBufferMessageSink extends AbstractMessageSink { try { + size += frame.getPayloadLength(); + long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); + if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) + { + throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", + size, maxBinaryMessageSize)); + } + // If we are fin and no OutputStream has been created we don't need to aggregate. if (frame.isFin() && (out == null)) { @@ -95,13 +103,6 @@ public class ByteBufferMessageSink extends AbstractMessageSink if (frame.hasPayload()) { ByteBuffer payload = frame.getPayload(); - size += payload.remaining(); - long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); - if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) - { - throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", - size, maxBinaryMessageSize)); - } if (out == null) out = new ByteArrayOutputStream(BUFFER_SIZE); diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java index bd029bc7914..10c2e59022e 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java @@ -45,6 +45,14 @@ public class StringMessageSink extends AbstractMessageSink { try { + size += frame.getPayloadLength(); + long maxTextMessageSize = session.getMaxTextMessageSize(); + if (maxTextMessageSize > 0 && size > maxTextMessageSize) + { + throw new MessageTooLargeException(String.format("Text message too large: (actual) %,d > (configured max text message size) %,d", + size, maxTextMessageSize)); + } + // If we are fin and out has not been created we don't need to aggregate. if (frame.isFin() && (out == null)) { @@ -83,13 +91,6 @@ public class StringMessageSink extends AbstractMessageSink if (frame.hasPayload()) { ByteBuffer payload = frame.getPayload(); - size += frame.getPayloadLength(); - long maxTextMessageSize = session.getMaxTextMessageSize(); - if (maxTextMessageSize > 0 && size > maxTextMessageSize) - { - throw new MessageTooLargeException(String.format("Text message too large: (actual) %,d > (configured max text message size) %,d", - size, maxTextMessageSize)); - } if (out == null) out = new Utf8StringBuilder(BUFFER_SIZE); From adbb3f165ec2615ce21163754fb4865551153359 Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Thu, 20 Feb 2020 16:55:02 +1100 Subject: [PATCH 4/4] Issue #4571 - review UTF-8 validation in MessageSinks Signed-off-by: Lachlan Roberts --- .../eclipse/jetty/websocket/core/Frame.java | 6 + .../util/messages/ByteArrayMessageSink.java | 1 - .../util/messages/ByteBufferMessageSink.java | 1 - .../messages/PartialByteArrayMessageSink.java | 1 - .../PartialByteBufferMessageSink.java | 8 +- .../messages/PartialStringMessageSink.java | 17 +- .../util/messages/StringMessageSink.java | 32 +--- .../util/PartialStringMessageSinkTest.java | 153 ++++++++++++++++++ .../websocket/util/StringMessageSinkTest.java | 147 +++++++++++++++++ 9 files changed, 324 insertions(+), 42 deletions(-) create mode 100644 jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/PartialStringMessageSinkTest.java create mode 100644 jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/StringMessageSinkTest.java diff --git a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/Frame.java b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/Frame.java index 0224eb8de73..5cd15ad0f33 100644 --- a/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/Frame.java +++ b/jetty-websocket/websocket-core/src/main/java/org/eclipse/jetty/websocket/core/Frame.java @@ -221,6 +221,12 @@ public class Frame return payload; } + /** + * Get the payload of the frame as a UTF-8 string. + *

Should only be used in testing, does not validate the + * UTF-8 and a non fin frame can contain partial UTF-8 characters.

+ * @return the payload as a UTF-8 string. + */ public String getPayloadAsUTF8() { if (payload == null) diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java index 6e02ef13c4b..1921e36b704 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteArrayMessageSink.java @@ -50,7 +50,6 @@ public class ByteArrayMessageSink extends AbstractMessageSink } } - @SuppressWarnings("Duplicates") @Override public void accept(Frame frame, Callback callback) { diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java index 85f9f8e2af3..0c255e4e267 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ByteBufferMessageSink.java @@ -51,7 +51,6 @@ public class ByteBufferMessageSink extends AbstractMessageSink } } - @SuppressWarnings("Duplicates") @Override public void accept(Frame frame, Callback callback) { diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java index b15c2c8c8b1..a42c3577837 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteArrayMessageSink.java @@ -43,7 +43,6 @@ public class PartialByteArrayMessageSink extends AbstractMessageSink } } - @SuppressWarnings("Duplicates") @Override public void accept(Frame frame, Callback callback) { diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java index 55ed3bbcf05..ce00ee7acef 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialByteBufferMessageSink.java @@ -19,9 +19,7 @@ package org.eclipse.jetty.websocket.util.messages; import java.lang.invoke.MethodHandle; -import java.nio.ByteBuffer; -import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; @@ -41,17 +39,13 @@ public class PartialByteBufferMessageSink extends AbstractMessageSink */ } - @SuppressWarnings("Duplicates") @Override public void accept(Frame frame, Callback callback) { try { if (frame.hasPayload() || frame.isFin()) - { - ByteBuffer buffer = frame.hasPayload() ? frame.getPayload() : BufferUtil.EMPTY_BUFFER; - methodHandle.invoke(buffer, frame.isFin()); - } + methodHandle.invoke(frame.getPayload(), frame.isFin()); callback.succeeded(); } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java index 0a8e93883be..457bb1461ea 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/PartialStringMessageSink.java @@ -22,26 +22,37 @@ import java.lang.invoke.MethodHandle; import java.util.Objects; import org.eclipse.jetty.util.Callback; +import org.eclipse.jetty.util.Utf8StringBuilder; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; public class PartialStringMessageSink extends AbstractMessageSink { + private Utf8StringBuilder out; + public PartialStringMessageSink(CoreSession session, MethodHandle methodHandle) { super(session, methodHandle); Objects.requireNonNull(methodHandle, "MethodHandle"); } - @SuppressWarnings("Duplicates") @Override public void accept(Frame frame, Callback callback) { try { - if (frame.hasPayload() || frame.isFin()) + if (out == null) + out = new Utf8StringBuilder(session.getInputBufferSize()); + + out.append(frame.getPayload()); + if (frame.isFin()) { - methodHandle.invoke(frame.getPayloadAsUTF8(), frame.isFin()); + methodHandle.invoke(out.toString(), true); + out = null; + } + else + { + methodHandle.invoke(out.takePartialString(), false); } callback.succeeded(); diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java index 10c2e59022e..78fef1afb1b 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/StringMessageSink.java @@ -18,9 +18,7 @@ package org.eclipse.jetty.websocket.util.messages; -import java.io.IOException; import java.lang.invoke.MethodHandle; -import java.nio.ByteBuffer; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Utf8StringBuilder; @@ -30,7 +28,6 @@ import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException; public class StringMessageSink extends AbstractMessageSink { - private static final int BUFFER_SIZE = 1024; private Utf8StringBuilder out; private int size; @@ -53,19 +50,10 @@ public class StringMessageSink extends AbstractMessageSink size, maxTextMessageSize)); } - // If we are fin and out has not been created we don't need to aggregate. - if (frame.isFin() && (out == null)) - { - if (frame.hasPayload()) - methodHandle.invoke(frame.getPayloadAsUTF8()); - else - methodHandle.invoke(""); + if (out == null) + out = new Utf8StringBuilder(session.getInputBufferSize()); - callback.succeeded(); - return; - } - - aggregatePayload(frame); + out.append(frame.getPayload()); if (frame.isFin()) methodHandle.invoke(out.toString()); @@ -85,18 +73,4 @@ public class StringMessageSink extends AbstractMessageSink } } } - - private void aggregatePayload(Frame frame) throws IOException - { - if (frame.hasPayload()) - { - ByteBuffer payload = frame.getPayload(); - - if (out == null) - out = new Utf8StringBuilder(BUFFER_SIZE); - - // allow for fast fail of BAD utf (incomplete utf will trigger on messageComplete) - out.append(payload); - } - } } diff --git a/jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/PartialStringMessageSinkTest.java b/jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/PartialStringMessageSinkTest.java new file mode 100644 index 00000000000..1cb9e174fc8 --- /dev/null +++ b/jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/PartialStringMessageSinkTest.java @@ -0,0 +1,153 @@ +// +// ======================================================================== +// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under +// the terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0 +// +// This Source Code may also be made available under the following +// Secondary Licenses when the conditions for such availability set +// forth in the Eclipse Public License, v. 2.0 are satisfied: +// the Apache License v2.0 which is available at +// https://www.apache.org/licenses/LICENSE-2.0 +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.websocket.util; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +import org.eclipse.jetty.util.BlockingArrayQueue; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.FutureCallback; +import org.eclipse.jetty.util.Utf8Appendable; +import org.eclipse.jetty.websocket.core.CoreSession; +import org.eclipse.jetty.websocket.core.Frame; +import org.eclipse.jetty.websocket.core.OpCode; +import org.eclipse.jetty.websocket.util.messages.PartialStringMessageSink; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PartialStringMessageSinkTest +{ + private CoreSession coreSession = new CoreSession.Empty(); + private OnMessageEndpoint endpoint = new OnMessageEndpoint(); + private PartialStringMessageSink messageSink; + + @BeforeEach + public void before() throws Exception + { + messageSink = new PartialStringMessageSink(coreSession, endpoint.getMethodHandle()); + } + + @Test + public void testValidUtf8() throws Exception + { + ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88}); + + FutureCallback callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback); + callback.block(5, TimeUnit.SECONDS); + + List message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS)); + assertThat(message.size(), is(1)); + assertThat(message.get(0), is("\uD800\uDF48")); + } + + @Test + public void testUtf8Continuation() throws Exception + { + ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90}); + ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D, (byte)0x88}); + + FutureCallback callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), callback); + callback.block(5, TimeUnit.SECONDS); + + callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), callback); + callback.block(5, TimeUnit.SECONDS); + + List message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS)); + assertThat(message.size(), is(2)); + assertThat(message.get(0), is("")); + assertThat(message.get(1), is("\uD800\uDF48")); + } + + @Test + public void testInvalidSingleFrameUtf8() throws Exception + { + ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D}); + + FutureCallback callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, invalidUtf8Payload).setFin(true), callback); + + // Callback should fail and we don't receive the message in the sink. + RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS)); + assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class)); + List message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS)); + assertTrue(message.isEmpty()); + } + + @Test + public void testInvalidMultiFrameUtf8() throws Exception + { + ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90}); + ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D}); + + FutureCallback firstCallback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), firstCallback); + firstCallback.block(5, TimeUnit.SECONDS); + + FutureCallback continuationCallback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), continuationCallback); + + // Callback should fail and we only received the first frame which had no full character. + RuntimeException error = assertThrows(RuntimeException.class, () -> continuationCallback.block(5, TimeUnit.SECONDS)); + assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class)); + List message = Objects.requireNonNull(endpoint.messages.poll(5, TimeUnit.SECONDS)); + assertThat(message.size(), is(1)); + assertThat(message.get(0), is("")); + } + + public static class OnMessageEndpoint + { + private BlockingArrayQueue> messages; + + public OnMessageEndpoint() + { + messages = new BlockingArrayQueue<>(); + messages.add(new ArrayList<>()); + } + + public void onMessage(String message, boolean last) + { + messages.get(messages.size() - 1).add(message); + if (last) + messages.add(new ArrayList<>()); + } + + public MethodHandle getMethodHandle() throws Exception + { + return MethodHandles.lookup() + .findVirtual(this.getClass(), "onMessage", MethodType.methodType(void.class, String.class, boolean.class)) + .bindTo(this); + } + } +} diff --git a/jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/StringMessageSinkTest.java b/jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/StringMessageSinkTest.java new file mode 100644 index 00000000000..f876c52a942 --- /dev/null +++ b/jetty-websocket/websocket-util/src/test/org/eclipse/jetty/websocket/util/StringMessageSinkTest.java @@ -0,0 +1,147 @@ +// +// ======================================================================== +// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under +// the terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0 +// +// This Source Code may also be made available under the following +// Secondary Licenses when the conditions for such availability set +// forth in the Eclipse Public License, v. 2.0 are satisfied: +// the Apache License v2.0 which is available at +// https://www.apache.org/licenses/LICENSE-2.0 +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.websocket.util; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; + +import org.eclipse.jetty.util.BlockingArrayQueue; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.FutureCallback; +import org.eclipse.jetty.util.Utf8Appendable; +import org.eclipse.jetty.websocket.core.CoreSession; +import org.eclipse.jetty.websocket.core.Frame; +import org.eclipse.jetty.websocket.core.OpCode; +import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException; +import org.eclipse.jetty.websocket.util.messages.StringMessageSink; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class StringMessageSinkTest +{ + private CoreSession coreSession = new CoreSession.Empty(); + private OnMessageEndpoint endpoint = new OnMessageEndpoint(); + + @Test + public void testMaxMessageSize() throws Exception + { + StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle()); + ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88}); + + FutureCallback callback = new FutureCallback(); + coreSession.setMaxTextMessageSize(3); + messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback); + + // Callback should fail and we don't receive the message in the sink. + RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS)); + assertThat(error.getCause(), instanceOf(MessageTooLargeException.class)); + assertNull(endpoint.messages.poll()); + } + + @Test + public void testValidUtf8() throws Exception + { + StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle()); + ByteBuffer utf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D, (byte)0x88}); + + FutureCallback callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, utf8Payload).setFin(true), callback); + callback.block(5, TimeUnit.SECONDS); + + assertThat(endpoint.messages.poll(5, TimeUnit.SECONDS), is("\uD800\uDF48")); + } + + @Test + public void testUtf8Continuation() throws Exception + { + StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle()); + ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90}); + ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D, (byte)0x88}); + + FutureCallback callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), callback); + callback.block(5, TimeUnit.SECONDS); + + callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), callback); + callback.block(5, TimeUnit.SECONDS); + + assertThat(endpoint.messages.poll(5, TimeUnit.SECONDS), is("\uD800\uDF48")); + } + + @Test + public void testInvalidSingleFrameUtf8() throws Exception + { + StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle()); + ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90, (byte)0x8D}); + + FutureCallback callback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, invalidUtf8Payload).setFin(true), callback); + + // Callback should fail and we don't receive the message in the sink. + RuntimeException error = assertThrows(RuntimeException.class, () -> callback.block(5, TimeUnit.SECONDS)); + assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class)); + assertNull(endpoint.messages.poll()); + } + + @Test + public void testInvalidMultiFrameUtf8() throws Exception + { + StringMessageSink messageSink = new StringMessageSink(coreSession, endpoint.getMethodHandle()); + ByteBuffer firstUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0xF0, (byte)0x90}); + ByteBuffer continuationUtf8Payload = BufferUtil.toBuffer(new byte[]{(byte)0x8D}); + + FutureCallback firstCallback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, firstUtf8Payload).setFin(false), firstCallback); + firstCallback.block(5, TimeUnit.SECONDS); + + FutureCallback continuationCallback = new FutureCallback(); + messageSink.accept(new Frame(OpCode.TEXT, continuationUtf8Payload).setFin(true), continuationCallback); + + // Callback should fail and we don't receive the message in the sink. + RuntimeException error = assertThrows(RuntimeException.class, () -> continuationCallback.block(5, TimeUnit.SECONDS)); + assertThat(error.getCause(), instanceOf(Utf8Appendable.NotUtf8Exception.class)); + assertNull(endpoint.messages.poll()); + } + + public static class OnMessageEndpoint + { + private BlockingArrayQueue messages = new BlockingArrayQueue<>(); + + public void onMessage(String message) + { + messages.add(message); + } + + public MethodHandle getMethodHandle() throws Exception + { + return MethodHandles.lookup() + .findVirtual(this.getClass(), "onMessage", MethodType.methodType(void.class, String.class)) + .bindTo(this); + } + } +}