From 026261f48290694db35935bf605fb53303fe8803 Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Wed, 18 Aug 2021 17:10:12 +1000 Subject: [PATCH] Issue #6566 - use counter in BufferCallbackAccumulator, fix InputStreamMessageSinkTest failures Signed-off-by: Lachlan Roberts --- .../jetty/io/BufferCallbackAccumulator.java | 16 ++++-- .../messages/ByteArrayMessageSink.java | 26 ++++----- .../messages/ByteBufferMessageSink.java | 20 +++---- .../javax/common/AbstractSessionTest.java | 55 +++++++++++------- .../messages/AbstractMessageSinkTest.java | 5 +- .../messages/InputStreamMessageSinkTest.java | 57 ++++++++++++------- 6 files changed, 101 insertions(+), 78 deletions(-) diff --git a/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java b/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java index 6459254ce13..982349e2066 100644 --- a/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java +++ b/jetty-io/src/main/java/org/eclipse/jetty/io/BufferCallbackAccumulator.java @@ -24,6 +24,7 @@ import org.eclipse.jetty.util.Callback; public class BufferCallbackAccumulator { private final List _entries = new ArrayList<>(); + private int _length; private static class Entry { @@ -40,6 +41,7 @@ public class BufferCallbackAccumulator public void addEntry(ByteBuffer buffer, Callback callback) { _entries.add(new Entry(buffer, callback)); + _length = Math.addExact(_length, buffer.remaining()); } /** @@ -49,16 +51,13 @@ public class BufferCallbackAccumulator */ public int getLength() { - int length = 0; - for (Entry entry : _entries) - length = Math.addExact(length, entry.buffer.remaining()); - return length; + return _length; } /** * @return a newly allocated byte array containing all content written into the accumulator. */ - public byte[] toByteArray() + public byte[] takeByteArray() { int length = getLength(); if (length == 0) @@ -76,10 +75,16 @@ public class BufferCallbackAccumulator for (Iterator iterator = _entries.iterator(); iterator.hasNext();) { Entry entry = iterator.next(); + _length = entry.buffer.remaining(); buffer.put(entry.buffer); iterator.remove(); entry.callback.succeeded(); } + + if (!_entries.isEmpty()) + throw new IllegalStateException("remaining entries: " + _entries.size()); + if (_length != 0) + throw new IllegalStateException("non-zero length: " + _length); } public void fail(Throwable t) @@ -89,5 +94,6 @@ public class BufferCallbackAccumulator entry.callback.failed(t); } _entries.clear(); + _length = 0; } } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java index 66d59d1d9e3..259feecb017 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteArrayMessageSink.java @@ -53,8 +53,8 @@ public class ByteArrayMessageSink extends AbstractMessageSink long maxBinaryMessageSize = session.getMaxBinaryMessageSize(); if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize) { - callback.failed(new MessageTooLargeException( - String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", size, maxBinaryMessageSize))); + throw new MessageTooLargeException( + String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", size, maxBinaryMessageSize)); } // If we are fin and no OutputStream has been created we don't need to aggregate. @@ -73,13 +73,20 @@ public class ByteArrayMessageSink extends AbstractMessageSink return; } - aggregatePayload(frame, callback); + // Aggregate the frame payload. + if (frame.hasPayload()) + { + ByteBuffer payload = frame.getPayload(); + if (out == null) + out = new BufferCallbackAccumulator(); + out.addEntry(payload, callback); + } // If the methodHandle throws we don't want to fail callback twice. callback = Callback.NOOP; if (frame.isFin()) { - byte[] buf = out.toByteArray(); + byte[] buf = out.takeByteArray(); methodHandle.invoke(buf, 0, buf.length); } @@ -101,15 +108,4 @@ public class ByteArrayMessageSink extends AbstractMessageSink } } } - - private void aggregatePayload(Frame frame, Callback callback) - { - if (frame.hasPayload()) - { - ByteBuffer payload = frame.getPayload(); - if (out == null) - out = new BufferCallbackAccumulator(); - out.addEntry(payload, callback); - } - } } diff --git a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java index 3b94c1692b9..25ccde49c02 100644 --- a/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java +++ b/jetty-websocket/websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/messages/ByteBufferMessageSink.java @@ -71,7 +71,14 @@ public class ByteBufferMessageSink extends AbstractMessageSink return; } - aggregatePayload(frame, callback); + // Aggregate the frame payload. + if (frame.hasPayload()) + { + ByteBuffer payload = frame.getPayload(); + if (out == null) + out = new BufferCallbackAccumulator(); + out.addEntry(payload, callback); + } // If the methodHandle throws we don't want to fail callback twice. callback = Callback.NOOP; @@ -110,15 +117,4 @@ public class ByteBufferMessageSink extends AbstractMessageSink } } } - - private void aggregatePayload(Frame frame, Callback callback) - { - if (frame.hasPayload()) - { - ByteBuffer payload = frame.getPayload(); - if (out == null) - out = new BufferCallbackAccumulator(); - out.addEntry(payload, callback); - } - } } diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java index 709ef0ef491..9fac9ab9bc1 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/AbstractSessionTest.java @@ -13,48 +13,35 @@ package org.eclipse.jetty.websocket.javax.common; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.Session; import org.eclipse.jetty.io.ByteBufferPool; -import org.eclipse.jetty.io.NullByteBufferPool; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.WebSocketComponents; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import static org.junit.jupiter.api.Assertions.assertTrue; + public abstract class AbstractSessionTest { protected static JavaxWebSocketSession session; - protected static JavaxWebSocketContainer container; - protected static WebSocketComponents components; + protected static JavaxWebSocketContainer container = new DummyContainer(); + protected static WebSocketComponents components = new WebSocketComponents(); + protected static TestCoreSession coreSession = new TestCoreSession(); @BeforeAll public static void initSession() throws Exception { - container = new DummyContainer(); container.start(); - components = new WebSocketComponents(); components.start(); Object websocketPojo = new DummyEndpoint(); UpgradeRequest upgradeRequest = new UpgradeRequestAdapter(); JavaxWebSocketFrameHandler frameHandler = container.newFrameHandler(websocketPojo, upgradeRequest); - ByteBufferPool bufferPool = new NullByteBufferPool(); - CoreSession coreSession = new CoreSession.Empty() - { - @Override - public WebSocketComponents getWebSocketComponents() - { - return components; - } - - @Override - public ByteBufferPool getByteBufferPool() - { - return bufferPool; - } - }; session = new JavaxWebSocketSession(container, coreSession, frameHandler, container.getFrameHandlerFactory() .newDefaultEndpointConfig(websocketPojo.getClass())); } @@ -66,6 +53,34 @@ public abstract class AbstractSessionTest container.stop(); } + public static class TestCoreSession extends CoreSession.Empty + { + private final Semaphore demand = new Semaphore(0); + + @Override + public WebSocketComponents getWebSocketComponents() + { + return components; + } + + @Override + public ByteBufferPool getByteBufferPool() + { + return components.getBufferPool(); + } + + public void waitForDemand(long timeout, TimeUnit timeUnit) throws InterruptedException + { + assertTrue(demand.tryAcquire(timeout, timeUnit)); + } + + @Override + public void demand(long n) + { + demand.release(); + } + } + public static class DummyEndpoint extends Endpoint { @Override diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/AbstractMessageSinkTest.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/AbstractMessageSinkTest.java index 83682c9a413..4c6281c74b5 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/AbstractMessageSinkTest.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/AbstractMessageSinkTest.java @@ -20,15 +20,12 @@ import java.util.function.Consumer; import javax.websocket.ClientEndpointConfig; import javax.websocket.Decoder; -import org.eclipse.jetty.websocket.core.WebSocketComponents; import org.eclipse.jetty.websocket.javax.common.AbstractSessionTest; import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketFrameHandlerFactory; import org.eclipse.jetty.websocket.javax.common.decoders.RegisteredDecoder; public abstract class AbstractMessageSinkTest extends AbstractSessionTest { - private final WebSocketComponents _components = new WebSocketComponents(); - public List toRegisteredDecoderList(Class clazz, Class objectType) { Class interfaceType; @@ -43,7 +40,7 @@ public abstract class AbstractMessageSinkTest extends AbstractSessionTest else throw new IllegalStateException(); - return List.of(new RegisteredDecoder(clazz, interfaceType, objectType, ClientEndpointConfig.Builder.create().build(), _components)); + return List.of(new RegisteredDecoder(clazz, interfaceType, objectType, ClientEndpointConfig.Builder.create().build(), components)); } public MethodHandle getAcceptHandle(Consumer copy, Class type) diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/InputStreamMessageSinkTest.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/InputStreamMessageSinkTest.java index 7e5d26b3180..ac8c82c443d 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/InputStreamMessageSinkTest.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/messages/InputStreamMessageSinkTest.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.io.InputStream; import java.lang.invoke.MethodHandle; import java.nio.ByteBuffer; +import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -51,10 +52,11 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest ByteBuffer data = BufferUtil.toBuffer("Hello World", UTF_8); sink.accept(new Frame(OpCode.BINARY).setPayload(data), finCallback); - finCallback.get(1, TimeUnit.SECONDS); // wait for callback + coreSession.waitForDemand(1, TimeUnit.SECONDS); + finCallback.get(1, TimeUnit.SECONDS); ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS); assertThat("FinCallback.done", finCallback.isDone(), is(true)); - assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Hello World")); + assertThat("Writer.contents", byteStream.toString(UTF_8), is("Hello World")); } @Test @@ -68,19 +70,22 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest ByteBuffer data1 = BufferUtil.toBuffer("Hello World", UTF_8); sink.accept(new Frame(OpCode.BINARY).setPayload(data1).setFin(true), fin1Callback); - fin1Callback.get(1, TimeUnit.SECONDS); // wait for callback (can't sent next message until this callback finishes) + // wait for demand (can't sent next message until a new frame is demanded) + coreSession.waitForDemand(1, TimeUnit.SECONDS); + fin1Callback.get(1, TimeUnit.SECONDS); ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS); assertThat("FinCallback.done", fin1Callback.isDone(), is(true)); - assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Hello World")); + assertThat("Writer.contents", byteStream.toString(UTF_8), is("Hello World")); FutureCallback fin2Callback = new FutureCallback(); ByteBuffer data2 = BufferUtil.toBuffer("Greetings Earthling", UTF_8); sink.accept(new Frame(OpCode.BINARY).setPayload(data2).setFin(true), fin2Callback); - fin2Callback.get(1, TimeUnit.SECONDS); // wait for callback + coreSession.waitForDemand(1, TimeUnit.SECONDS); + fin2Callback.get(1, TimeUnit.SECONDS); byteStream = copy.poll(1, TimeUnit.SECONDS); assertThat("FinCallback.done", fin2Callback.isDone(), is(true)); - assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Greetings Earthling")); + assertThat("Writer.contents", byteStream.toString(UTF_8), is("Greetings Earthling")); } @Test @@ -95,16 +100,19 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest FutureCallback finCallback = new FutureCallback(); sink.accept(new Frame(OpCode.BINARY).setPayload("Hello").setFin(false), callback1); - sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2); - sink.accept(new Frame(OpCode.CONTINUATION).setPayload("World").setFin(true), finCallback); + coreSession.waitForDemand(1, TimeUnit.SECONDS); + assertThat("callback1.done", callback1.isDone(), is(true)); - finCallback.get(1, TimeUnit.SECONDS); // wait for callback - ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS); - assertThat("Callback1.done", callback1.isDone(), is(true)); - assertThat("Callback2.done", callback2.isDone(), is(true)); + sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2); + coreSession.waitForDemand(1, TimeUnit.SECONDS); + assertThat("callback2.done", callback2.isDone(), is(true)); + + sink.accept(new Frame(OpCode.CONTINUATION).setPayload("World").setFin(true), finCallback); + coreSession.waitForDemand(1, TimeUnit.SECONDS); assertThat("finCallback.done", finCallback.isDone(), is(true)); - assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Hello, World")); + ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS); + assertThat("Writer.contents", byteStream.toString(UTF_8), is("Hello, World")); } @Test @@ -120,18 +128,23 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest FutureCallback finCallback = new FutureCallback(); sink.accept(new Frame(OpCode.BINARY).setPayload("Greetings").setFin(false), callback1); - sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2); - sink.accept(new Frame(OpCode.CONTINUATION).setPayload("Earthling").setFin(false), callback3); - sink.accept(new Frame(OpCode.CONTINUATION).setPayload(new byte[0]).setFin(true), finCallback); - - finCallback.get(5, TimeUnit.SECONDS); // wait for callback - ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS); + coreSession.waitForDemand(1, TimeUnit.SECONDS); assertThat("Callback1.done", callback1.isDone(), is(true)); + + sink.accept(new Frame(OpCode.CONTINUATION).setPayload(", ").setFin(false), callback2); + coreSession.waitForDemand(1, TimeUnit.SECONDS); assertThat("Callback2.done", callback2.isDone(), is(true)); + + sink.accept(new Frame(OpCode.CONTINUATION).setPayload("Earthling").setFin(false), callback3); + coreSession.waitForDemand(1, TimeUnit.SECONDS); assertThat("Callback3.done", callback3.isDone(), is(true)); + + sink.accept(new Frame(OpCode.CONTINUATION).setPayload(new byte[0]).setFin(true), finCallback); + coreSession.waitForDemand(1, TimeUnit.SECONDS); assertThat("finCallback.done", finCallback.isDone(), is(true)); - assertThat("Writer.contents", new String(byteStream.toByteArray(), UTF_8), is("Greetings, Earthling")); + ByteArrayOutputStream byteStream = copy.poll(1, TimeUnit.SECONDS); + assertThat("Writer.contents", byteStream.toString(UTF_8), is("Greetings, Earthling")); } public static class InputStreamCopy implements Consumer @@ -156,9 +169,9 @@ public class InputStreamMessageSinkTest extends AbstractMessageSinkTest } } - public ByteArrayOutputStream poll(long time, TimeUnit unit) throws InterruptedException, ExecutionException + public ByteArrayOutputStream poll(long time, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - return streams.poll(time, unit).get(); + return Objects.requireNonNull(streams.poll(time, unit)).get(time, unit); } } }