From 5c839d791d9914a035ccf18f0c658e9f4821f26d Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Wed, 19 Feb 2020 19:31:26 +1100 Subject: [PATCH] Issue #4538 - allow MessageInputStream not to read to EOF Signed-off-by: Lachlan Roberts --- .../org/eclipse/jetty/util/StringUtil.java | 16 +++ .../javax/tests/server/TextStreamTest.java | 107 +++++++++++------- .../util/messages/DispatchedMessageSink.java | 5 + .../util/messages/MessageInputStream.java | 26 +++-- 4 files changed, 98 insertions(+), 56 deletions(-) diff --git a/jetty-util/src/main/java/org/eclipse/jetty/util/StringUtil.java b/jetty-util/src/main/java/org/eclipse/jetty/util/StringUtil.java index d59a34763b0..77aebf33779 100644 --- a/jetty-util/src/main/java/org/eclipse/jetty/util/StringUtil.java +++ b/jetty-util/src/main/java/org/eclipse/jetty/util/StringUtil.java @@ -438,6 +438,22 @@ public class StringUtil } } + /** + * Generate a string from another string repeated n times. + * + * @param s the string to use + * @param n the number of times this string should be appended + */ + public static String stringFrom(String s, int n) + { + StringBuilder stringBuilder = new StringBuilder(s.length() * n); + for (int i = 0; i < n; i++) + { + stringBuilder.append(s); + } + return stringBuilder.toString(); + } + /** * Return a non null string. * diff --git a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java index 543ca5ecabd..e420db47bbc 100644 --- a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java +++ b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java @@ -26,12 +26,12 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import javax.websocket.ClientEndpoint; +import javax.websocket.ClientEndpointConfig; import javax.websocket.ContainerProvider; -import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.MessageHandler; import javax.websocket.OnMessage; @@ -43,14 +43,17 @@ import javax.websocket.server.ServerEndpointConfig; import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.IO; +import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.core.CloseStatus; import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.OpCode; +import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession; import org.eclipse.jetty.websocket.javax.tests.DataUtils; import org.eclipse.jetty.websocket.javax.tests.Fuzzer; import org.eclipse.jetty.websocket.javax.tests.LocalServer; +import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker; import org.hamcrest.Matchers; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -59,12 +62,15 @@ import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TextStreamTest { private static final Logger LOG = Log.getLogger(TextStreamTest.class); private static final BlockingArrayQueue serverEndpoints = new BlockingArrayQueue<>(); + private final ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().build(); private LocalServer server; private ServerContainer container; private WebSocketContainer wsClient; @@ -172,7 +178,7 @@ public class TextStreamTest public void testMessageOrdering() throws Exception { ClientTextStreamer client = new ClientTextStreamer(); - Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test")); + Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test")); final int numLoops = 20; for (int i = 0; i < numLoops; i++) @@ -194,7 +200,7 @@ public class TextStreamTest public void testFragmentedMessageOrdering() throws Exception { ClientTextStreamer client = new ClientTextStreamer(); - Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test")); + Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test")); final int numLoops = 20; for (int i = 0; i < numLoops; i++) @@ -218,52 +224,68 @@ public class TextStreamTest @Test public void testMessageOrderingDoNotReadToEOF() throws Exception { - ClientTextStreamer client = new ClientTextStreamer(); - Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/partial")); + ClientTextStreamer clientEndpoint = new ClientTextStreamer(); + Session session = wsClient.connectToServer(clientEndpoint, clientConfig, server.getWsUri().resolve("/partial")); + QueuedTextStreamer serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS)); + + int serverInputBufferSize = 1024; + JavaxWebSocketSession serverSession = (JavaxWebSocketSession)serverEndpoint.session; + serverSession.getCoreSession().setInputBufferSize(serverInputBufferSize); + + // Write some initial data. + Writer writer = session.getBasicRemote().getSendWriter(); + writer.write("first frame"); + writer.flush(); + + // Signal to stop reading. + writer.write("|"); + writer.flush(); + + // Lots of data after we have stopped reading and onMessage exits. + final String largePayload = StringUtil.stringFrom("x", serverInputBufferSize * 2); + writer.write(largePayload); + writer.close(); - final int numLoops = 20; - for (int i = 0; i < numLoops; i++) - { - session.getBasicRemote().sendText(i + "|-----"); - } session.close(); + assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS)); + assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS)); + assertNull(clientEndpoint.error.get()); + assertNull(serverEndpoint.error.get()); - QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS); - assertNotNull(queuedTextStreamer); - for (int i = 0; i < numLoops; i++) - { - String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS); - assertThat(msg, Matchers.is(Integer.toString(i))); - } + String msg = serverEndpoint.messages.poll(5, TimeUnit.SECONDS); + assertThat(msg, Matchers.is("first frame")); } - @ClientEndpoint - public static class ClientTextStreamer + public static class ClientTextStreamer extends WSEndpointTracker implements MessageHandler.Whole { private final CountDownLatch latch = new CountDownLatch(1); private final StringBuilder output = new StringBuilder(); - @OnMessage - public void echoed(Reader input) throws IOException + @Override + public void onOpen(Session session, EndpointConfig config) { - while (true) + session.addMessageHandler(this); + super.onOpen(session, config); + } + + @Override + public void onMessage(Reader input) + { + try { - int read = input.read(); - if (read < 0) - break; - output.append((char)read); + while (true) + { + int read = input.read(); + if (read < 0) + break; + output.append((char)read); + } + latch.countDown(); + } + catch (IOException e) + { + throw new RuntimeException(e); } - latch.countDown(); - } - - public char[] getEcho() - { - return output.toString().toCharArray(); - } - - public boolean await(long timeout, TimeUnit unit) throws InterruptedException - { - return latch.await(timeout, unit); } } @@ -289,19 +311,16 @@ public class TextStreamTest } } - public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole + public static class QueuedTextStreamer extends WSEndpointTracker implements MessageHandler.Whole { protected BlockingArrayQueue messages = new BlockingArrayQueue<>(); - public QueuedTextStreamer() - { - serverEndpoints.add(this); - } - @Override public void onOpen(Session session, EndpointConfig config) { session.addMessageHandler(this); + super.onOpen(session, config); + serverEndpoints.add(this); } @Override diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java index 1d6678b3bcc..4060d4b13d2 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java @@ -18,6 +18,7 @@ package org.eclipse.jetty.websocket.util.messages; +import java.io.Closeable; import java.lang.invoke.MethodHandle; import java.util.concurrent.CompletableFuture; @@ -120,6 +121,10 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink { methodHandle.invoke(typeSink); dispatchComplete.complete(null); + + // If the MessageSink can be closed do this to free up resources. + if (typeSink instanceof Closeable) + ((Closeable)typeSink).close(); } catch (Throwable throwable) { diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java index 3ceada2005c..6b6528ee19e 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java @@ -102,14 +102,14 @@ public class MessageInputStream extends InputStream implements MessageSink public int read(ByteBuffer buffer) throws IOException { - Entry result = getCurrentEntry(); + Entry currentEntry = getCurrentEntry(); if (LOG.isDebugEnabled()) - LOG.debug("result = {}", result); + LOG.debug("currentEntry = {}", currentEntry); - if (result == CLOSED) + if (currentEntry == CLOSED) throw new IOException("Closed"); - if (result == EOF) + if (currentEntry == EOF) { if (LOG.isDebugEnabled()) LOG.debug("Read EOF"); @@ -117,11 +117,13 @@ public class MessageInputStream extends InputStream implements MessageSink } // We have content. - int fillLen = BufferUtil.append(buffer, result.buffer); - if (!result.buffer.hasRemaining()) + int fillLen = BufferUtil.append(buffer, currentEntry.buffer); + if (!currentEntry.buffer.hasRemaining()) succeedCurrentEntry(); // Return number of bytes actually copied into buffer. + if (LOG.isDebugEnabled()) + LOG.debug("filled {} bytes from {}", fillLen, currentEntry); return fillLen; } @@ -131,7 +133,7 @@ public class MessageInputStream extends InputStream implements MessageSink if (LOG.isDebugEnabled()) LOG.debug("close()"); - ArrayList failedEntries = new ArrayList<>(); + ArrayList entries = new ArrayList<>(); synchronized (this) { if (closed) @@ -140,20 +142,20 @@ public class MessageInputStream extends InputStream implements MessageSink if (currentEntry != null) { - failedEntries.add(currentEntry); + entries.add(currentEntry); currentEntry = null; } // Clear queue and fail all entries. - failedEntries.addAll(buffers); + entries.addAll(buffers); buffers.clear(); buffers.offer(CLOSED); } - Throwable cause = new IOException("Closed"); - for (Entry e : failedEntries) + // Succeed all entries as we don't need them anymore (failing would close the connection). + for (Entry e : entries) { - e.callback.failed(cause); + e.callback.succeeded(); } super.close();