Issue #4538 - allow MessageInputStream not to read to EOF

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-02-19 19:31:26 +11:00
parent 97abed549b
commit 5c839d791d
4 changed files with 98 additions and 56 deletions

View File

@ -438,6 +438,22 @@ public class StringUtil
}
}
/**
* Generate a string from another string repeated n times.
*
* @param s the string to use
* @param n the number of times this string should be appended
*/
public static String stringFrom(String s, int n)
{
StringBuilder stringBuilder = new StringBuilder(s.length() * n);
for (int i = 0; i < n; i++)
{
stringBuilder.append(s);
}
return stringBuilder.toString();
}
/**
* Return a non null string.
*

View File

@ -26,12 +26,12 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
@ -43,14 +43,17 @@ import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession;
import org.eclipse.jetty.websocket.javax.tests.DataUtils;
import org.eclipse.jetty.websocket.javax.tests.Fuzzer;
import org.eclipse.jetty.websocket.javax.tests.LocalServer;
import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@ -59,12 +62,15 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TextStreamTest
{
private static final Logger LOG = Log.getLogger(TextStreamTest.class);
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();
private final ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().build();
private LocalServer server;
private ServerContainer container;
private WebSocketContainer wsClient;
@ -172,7 +178,7 @@ public class TextStreamTest
public void testMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test"));
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
@ -194,7 +200,7 @@ public class TextStreamTest
public void testFragmentedMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test"));
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
@ -218,52 +224,68 @@ public class TextStreamTest
@Test
public void testMessageOrderingDoNotReadToEOF() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/partial"));
ClientTextStreamer clientEndpoint = new ClientTextStreamer();
Session session = wsClient.connectToServer(clientEndpoint, clientConfig, server.getWsUri().resolve("/partial"));
QueuedTextStreamer serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
int serverInputBufferSize = 1024;
JavaxWebSocketSession serverSession = (JavaxWebSocketSession)serverEndpoint.session;
serverSession.getCoreSession().setInputBufferSize(serverInputBufferSize);
// Write some initial data.
Writer writer = session.getBasicRemote().getSendWriter();
writer.write("first frame");
writer.flush();
// Signal to stop reading.
writer.write("|");
writer.flush();
// Lots of data after we have stopped reading and onMessage exits.
final String largePayload = StringUtil.stringFrom("x", serverInputBufferSize * 2);
writer.write(largePayload);
writer.close();
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(i + "|-----");
}
session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertNull(clientEndpoint.error.get());
assertNull(serverEndpoint.error.get());
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
String msg = serverEndpoint.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is("first frame"));
}
@ClientEndpoint
public static class ClientTextStreamer
public static class ClientTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
{
private final CountDownLatch latch = new CountDownLatch(1);
private final StringBuilder output = new StringBuilder();
@OnMessage
public void echoed(Reader input) throws IOException
@Override
public void onOpen(Session session, EndpointConfig config)
{
while (true)
session.addMessageHandler(this);
super.onOpen(session, config);
}
@Override
public void onMessage(Reader input)
{
try
{
int read = input.read();
if (read < 0)
break;
output.append((char)read);
while (true)
{
int read = input.read();
if (read < 0)
break;
output.append((char)read);
}
latch.countDown();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
latch.countDown();
}
public char[] getEcho()
{
return output.toString().toCharArray();
}
public boolean await(long timeout, TimeUnit unit) throws InterruptedException
{
return latch.await(timeout, unit);
}
}
@ -289,19 +311,16 @@ public class TextStreamTest
}
}
public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole<Reader>
public static class QueuedTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
{
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();
public QueuedTextStreamer()
{
serverEndpoints.add(this);
}
@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
super.onOpen(session, config);
serverEndpoints.add(this);
}
@Override

View File

@ -18,6 +18,7 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.Closeable;
import java.lang.invoke.MethodHandle;
import java.util.concurrent.CompletableFuture;
@ -120,6 +121,10 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
{
methodHandle.invoke(typeSink);
dispatchComplete.complete(null);
// If the MessageSink can be closed do this to free up resources.
if (typeSink instanceof Closeable)
((Closeable)typeSink).close();
}
catch (Throwable throwable)
{

View File

@ -102,14 +102,14 @@ public class MessageInputStream extends InputStream implements MessageSink
public int read(ByteBuffer buffer) throws IOException
{
Entry result = getCurrentEntry();
Entry currentEntry = getCurrentEntry();
if (LOG.isDebugEnabled())
LOG.debug("result = {}", result);
LOG.debug("currentEntry = {}", currentEntry);
if (result == CLOSED)
if (currentEntry == CLOSED)
throw new IOException("Closed");
if (result == EOF)
if (currentEntry == EOF)
{
if (LOG.isDebugEnabled())
LOG.debug("Read EOF");
@ -117,11 +117,13 @@ public class MessageInputStream extends InputStream implements MessageSink
}
// We have content.
int fillLen = BufferUtil.append(buffer, result.buffer);
if (!result.buffer.hasRemaining())
int fillLen = BufferUtil.append(buffer, currentEntry.buffer);
if (!currentEntry.buffer.hasRemaining())
succeedCurrentEntry();
// Return number of bytes actually copied into buffer.
if (LOG.isDebugEnabled())
LOG.debug("filled {} bytes from {}", fillLen, currentEntry);
return fillLen;
}
@ -131,7 +133,7 @@ public class MessageInputStream extends InputStream implements MessageSink
if (LOG.isDebugEnabled())
LOG.debug("close()");
ArrayList<Entry> failedEntries = new ArrayList<>();
ArrayList<Entry> entries = new ArrayList<>();
synchronized (this)
{
if (closed)
@ -140,20 +142,20 @@ public class MessageInputStream extends InputStream implements MessageSink
if (currentEntry != null)
{
failedEntries.add(currentEntry);
entries.add(currentEntry);
currentEntry = null;
}
// Clear queue and fail all entries.
failedEntries.addAll(buffers);
entries.addAll(buffers);
buffers.clear();
buffers.offer(CLOSED);
}
Throwable cause = new IOException("Closed");
for (Entry e : failedEntries)
// Succeed all entries as we don't need them anymore (failing would close the connection).
for (Entry e : entries)
{
e.callback.failed(cause);
e.callback.succeeded();
}
super.close();