From 5236e47c42dd367c5202ebb4c1201dda9f493acc Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Wed, 18 Aug 2021 11:01:35 +1000 Subject: [PATCH] Issue #6566 - utilise the demand interface in the websocket MessageSinks Signed-off-by: Lachlan Roberts --- .../jetty/io/BufferCallbackAccumulator.java | 21 ++------- .../messages/ByteArrayMessageSink.java | 27 ++++++----- .../messages/ByteBufferMessageSink.java | 46 +++++++++++++------ .../messages/DispatchedMessageSink.java | 22 +++++---- .../messages/PartialByteArrayMessageSink.java | 1 + .../PartialByteBufferMessageSink.java | 1 + .../messages/PartialStringMessageSink.java | 1 + .../internal/messages/StringMessageSink.java | 1 + .../common/JavaxWebSocketFrameHandler.java | 9 ++++ .../javax/common/AbstractSessionTest.java | 9 ++++ .../common/JettyWebSocketFrameHandler.java | 20 +++----- 11 files changed, 94 insertions(+), 64 deletions(-) diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java b/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java index 9aaa662beff..6459254ce13 100644 --- a/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java +++ b/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java @@ -24,30 +24,17 @@ import org.eclipse.jetty.util.Callback; public class BufferCallbackAccumulator { private final List _entries = new ArrayList<>(); - private final ByteBufferPool _bufferPool; - private final boolean _direct; private static class Entry { + private final ByteBuffer buffer; + private final Callback callback; + Entry(ByteBuffer buffer, Callback callback) { this.buffer = buffer; this.callback = callback; } - - ByteBuffer buffer; - Callback callback; - } - - public BufferCallbackAccumulator() - { - this(null, false); - } - - BufferCallbackAccumulator(ByteBufferPool bufferPool, boolean direct) - { - _bufferPool = (bufferPool == null) ? new NullByteBufferPool() : bufferPool; - _direct = direct; } public void addEntry(ByteBuffer buffer, Callback callback) @@ -86,7 +73,7 @@ public class BufferCallbackAccumulator public void writeTo(ByteBuffer buffer) { - for (Iterator iterator = _entries.iterator(); iterator.hasNext(); ) + for (Iterator iterator = _entries.iterator(); iterator.hasNext();) { Entry entry = iterator.next(); buffer.put(entry.buffer); diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java index 1eeda1230dc..66d59d1d9e3 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java @@ -13,12 +13,11 @@ package org.eclipse.jetty.websocket.core.internal.messages; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; import java.nio.ByteBuffer; +import org.eclipse.jetty.io.BufferCallbackAccumulator; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.websocket.core.CoreSession; @@ -29,8 +28,7 @@ import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException; public class ByteArrayMessageSink extends AbstractMessageSink { private static final byte[] EMPTY_BUFFER = new byte[0]; - private static final int BUFFER_SIZE = 65535; - private ByteArrayOutputStream out; + private BufferCallbackAccumulator out; private int size; public ByteArrayMessageSink(CoreSession session, MethodHandle methodHandle) @@ -55,8 +53,8 @@ public class ByteArrayMessageSink extends AbstractMessageSink long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) { - throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", - size, maxBinaryMessageSize)); + callback.failed(new MessageTooLargeException( + String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", size, maxBinaryMessageSize))); } // If we are fin and no OutputStream has been created we don't need to aggregate. @@ -71,19 +69,26 @@ public class ByteArrayMessageSink extends AbstractMessageSink methodHandle.invoke(EMPTY_BUFFER, 0, 0); callback.succeeded(); + session.demand(1); return; } - aggregatePayload(frame); + aggregatePayload(frame, callback); + + // If the methodHandle throws we don't want to fail callback twice. + callback = Callback.NOOP; if (frame.isFin()) { byte[] buf = out.toByteArray(); methodHandle.invoke(buf, 0, buf.length); } - callback.succeeded(); + + session.demand(1); } catch (Throwable t) { + if (out != null) + out.fail(t); callback.failed(t); } finally @@ -97,14 +102,14 @@ public class ByteArrayMessageSink extends AbstractMessageSink } } - private void aggregatePayload(Frame frame) throws IOException + private void aggregatePayload(Frame frame, Callback callback) { if (frame.hasPayload()) { ByteBuffer payload = frame.getPayload(); if (out == null) - out = new ByteArrayOutputStream(BUFFER_SIZE); - BufferUtil.writeTo(payload, out); + out = new BufferCallbackAccumulator(); + out.addEntry(payload, callback); } } } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java index 677406d36b7..0130a1c45d6 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java @@ -13,13 +13,13 @@ package org.eclipse.jetty.websocket.core.internal.messages; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; import java.nio.ByteBuffer; import java.util.Objects; +import org.eclipse.jetty.io.BufferCallbackAccumulator; +import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.websocket.core.CoreSession; @@ -29,8 +29,7 @@ import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException; public class ByteBufferMessageSink extends AbstractMessageSink { - private static final int BUFFER_SIZE = 65535; - private ByteArrayOutputStream out; + private BufferCallbackAccumulator out; private int size; public ByteBufferMessageSink(CoreSession session, MethodHandle methodHandle) @@ -68,41 +67,58 @@ public class ByteBufferMessageSink extends AbstractMessageSink methodHandle.invoke(BufferUtil.EMPTY_BUFFER); callback.succeeded(); + session.demand(1); return; } - aggregatePayload(frame); - if (frame.isFin()) - methodHandle.invoke(ByteBuffer.wrap(out.toByteArray())); + aggregatePayload(frame, callback); - callback.succeeded(); + // If the methodHandle throws we don't want to fail callback twice. + callback = Callback.NOOP; + if (frame.isFin()) + { + ByteBufferPool bufferPool = session.getByteBufferPool(); + ByteBuffer buffer = bufferPool.acquire(out.getLength(), false); + BufferUtil.clearToFill(buffer); + out.writeTo(buffer); + BufferUtil.flipToFlush(buffer, 0); + + try + { + methodHandle.invoke(buffer); + } + finally + { + bufferPool.release(buffer); + } + + session.demand(1); + } } catch (Throwable t) { + if (out != null) + out.fail(t); callback.failed(t); } finally { if (frame.isFin()) { - // reset out = null; size = 0; } } } - private void aggregatePayload(Frame frame) throws IOException + private void aggregatePayload(Frame frame, Callback callback) { if (frame.hasPayload()) { ByteBuffer payload = frame.getPayload(); - if (out == null) - out = new ByteArrayOutputStream(BUFFER_SIZE); - - BufferUtil.writeTo(payload, out); - payload.position(payload.limit()); // consume buffer + out = new BufferCallbackAccumulator(); + out.addEntry(payload, callback); } } } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/DispatchedMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/DispatchedMessageSink.java index 57d31e0d3f6..dfa37741597 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/DispatchedMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/DispatchedMessageSink.java @@ -135,22 +135,28 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink }); } - Callback frameCallback = callback; + Callback frameCallback; if (frame.isFin()) { // This is the final frame we should wait for the frame callback and the dispatched thread. - Callback.Completable completableCallback = new Callback.Completable(); - frameCallback = completableCallback; - CompletableFuture.allOf(dispatchComplete, completableCallback).whenComplete((aVoid, throwable) -> + Callback.Completable finComplete = Callback.Completable.from(callback); + frameCallback = finComplete; + CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete((aVoid, throwable) -> { typeSink = null; dispatchComplete = null; - if (throwable != null) - callback.failed(throwable); - else - callback.succeeded(); + if (throwable == null) + session.demand(1); }); } + else + { + frameCallback = Callback.from(() -> + { + callback.succeeded(); + session.demand(1); + }, callback::failed); + } typeSink.accept(frame, frameCallback); } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteArrayMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteArrayMessageSink.java index 59a39984030..076ef5f5c37 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteArrayMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteArrayMessageSink.java @@ -41,6 +41,7 @@ public class PartialByteArrayMessageSink extends AbstractMessageSink } callback.succeeded(); + session.demand(1); } catch (Throwable t) { diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteBufferMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteBufferMessageSink.java index f463645e794..0331c477c7c 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteBufferMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialByteBufferMessageSink.java @@ -35,6 +35,7 @@ public class PartialByteBufferMessageSink extends AbstractMessageSink methodHandle.invoke(frame.getPayload(), frame.isFin()); callback.succeeded(); + session.demand(1); } catch (Throwable t) { diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialStringMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialStringMessageSink.java index 6e01aa3327f..9061884c7fd 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialStringMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/PartialStringMessageSink.java @@ -51,6 +51,7 @@ public class PartialStringMessageSink extends AbstractMessageSink } callback.succeeded(); + session.demand(1); } catch (Throwable t) { diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/StringMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/StringMessageSink.java index 6cad40d2279..810e4330a11 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/StringMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/StringMessageSink.java @@ -53,6 +53,7 @@ public class StringMessageSink extends AbstractMessageSink methodHandle.invoke(out.toString()); callback.succeeded(); + session.demand(1); } catch (Throwable t) { diff --git a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java index 03031217824..24fb6d45167 100644 --- a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java +++ b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandler.java @@ -178,6 +178,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionOpened(session)); callback.succeeded(); + coreSession.demand(1); } catch (Throwable cause) { @@ -321,6 +322,12 @@ public class JavaxWebSocketFrameHandler implements FrameHandler } } + @Override + public boolean isDemanding() + { + return true; + } + public Set getMessageHandlers() { return messageHandlerMap.values().stream() @@ -591,6 +598,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler ByteBuffer payload = BufferUtil.copy(frame.getPayload()); coreSession.sendFrame(new Frame(OpCode.PONG).setPayload(payload), Callback.NOOP, false); callback.succeeded(); + coreSession.demand(1); } public void onPong(Frame frame, Callback callback) @@ -613,6 +621,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler } } callback.succeeded(); + coreSession.demand(1); } public void onText(Frame frame, Callback callback) diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java index 30fa05edad7..709ef0ef491 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java @@ -17,6 +17,8 @@ import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.Session; +import org.eclipse.jetty.io.ByteBufferPool; +import org.eclipse.jetty.io.NullByteBufferPool; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.WebSocketComponents; import org.junit.jupiter.api.AfterAll; @@ -38,6 +40,7 @@ public abstract class AbstractSessionTest Object websocketPojo = new DummyEndpoint(); UpgradeRequest upgradeRequest = new UpgradeRequestAdapter(); JavaxWebSocketFrameHandler frameHandler = container.newFrameHandler(websocketPojo, upgradeRequest); + ByteBufferPool bufferPool = new NullByteBufferPool(); CoreSession coreSession = new CoreSession.Empty() { @Override @@ -45,6 +48,12 @@ public abstract class AbstractSessionTest { return components; } + + @Override + public ByteBufferPool getByteBufferPool() + { + return bufferPool; + } }; session = new JavaxWebSocketSession(container, coreSession, frameHandler, container.getFrameHandlerFactory() .newDefaultEndpointConfig(websocketPojo.getClass())); diff --git a/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java b/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java index 992c9042dd2..d5584ff6f3d 100644 --- a/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java +++ b/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java @@ -82,6 +82,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler private WebSocketSession session; private SuspendState state = SuspendState.DEMANDING; private Runnable delayedOnFrame; + private CoreSession coreSession; public JettyWebSocketFrameHandler(WebSocketContainer container, Object endpointInstance, @@ -150,6 +151,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler try { customizer.customize(coreSession); + this.coreSession = coreSession; session = new WebSocketSession(container, coreSession, this); if (!session.isOpen()) throw new IllegalStateException("Session is not open"); @@ -226,16 +228,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler // Demand after succeeding any received frame Callback demandingCallback = Callback.from(() -> { - try - { - demand(); - } - catch (Throwable t) - { - callback.failed(t); - return; - } - + demand(); callback.succeeded(); }, callback::failed @@ -253,13 +246,13 @@ public class JettyWebSocketFrameHandler implements FrameHandler onPongFrame(frame, demandingCallback); break; case OpCode.TEXT: - onTextFrame(frame, demandingCallback); + onTextFrame(frame, callback); break; case OpCode.BINARY: - onBinaryFrame(frame, demandingCallback); + onBinaryFrame(frame, callback); break; case OpCode.CONTINUATION: - onContinuationFrame(frame, demandingCallback); + onContinuationFrame(frame, callback); break; default: callback.failed(new IllegalStateException()); @@ -342,6 +335,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler if (activeMessageSink == null) { callback.succeeded(); + coreSession.demand(1); return; }