Issue #4538 - allow MessageInputStream not to read to EOF
Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
parent
97abed549b
commit
5c839d791d
|
@ -438,6 +438,22 @@ public class StringUtil
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a string from another string repeated n times.
|
||||
*
|
||||
* @param s the string to use
|
||||
* @param n the number of times this string should be appended
|
||||
*/
|
||||
public static String stringFrom(String s, int n)
|
||||
{
|
||||
StringBuilder stringBuilder = new StringBuilder(s.length() * n);
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
stringBuilder.append(s);
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a non null string.
|
||||
*
|
||||
|
|
|
@ -26,12 +26,12 @@ import java.nio.ByteBuffer;
|
|||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import javax.websocket.ClientEndpoint;
|
||||
import javax.websocket.ClientEndpointConfig;
|
||||
import javax.websocket.ContainerProvider;
|
||||
import javax.websocket.Endpoint;
|
||||
import javax.websocket.EndpointConfig;
|
||||
import javax.websocket.MessageHandler;
|
||||
import javax.websocket.OnMessage;
|
||||
|
@ -43,14 +43,17 @@ import javax.websocket.server.ServerEndpointConfig;
|
|||
|
||||
import org.eclipse.jetty.util.BlockingArrayQueue;
|
||||
import org.eclipse.jetty.util.IO;
|
||||
import org.eclipse.jetty.util.StringUtil;
|
||||
import org.eclipse.jetty.util.log.Log;
|
||||
import org.eclipse.jetty.util.log.Logger;
|
||||
import org.eclipse.jetty.websocket.core.CloseStatus;
|
||||
import org.eclipse.jetty.websocket.core.Frame;
|
||||
import org.eclipse.jetty.websocket.core.OpCode;
|
||||
import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession;
|
||||
import org.eclipse.jetty.websocket.javax.tests.DataUtils;
|
||||
import org.eclipse.jetty.websocket.javax.tests.Fuzzer;
|
||||
import org.eclipse.jetty.websocket.javax.tests.LocalServer;
|
||||
import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker;
|
||||
import org.hamcrest.Matchers;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
|
@ -59,12 +62,15 @@ import org.junit.jupiter.api.Test;
|
|||
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TextStreamTest
|
||||
{
|
||||
private static final Logger LOG = Log.getLogger(TextStreamTest.class);
|
||||
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();
|
||||
|
||||
private final ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().build();
|
||||
private LocalServer server;
|
||||
private ServerContainer container;
|
||||
private WebSocketContainer wsClient;
|
||||
|
@ -172,7 +178,7 @@ public class TextStreamTest
|
|||
public void testMessageOrdering() throws Exception
|
||||
{
|
||||
ClientTextStreamer client = new ClientTextStreamer();
|
||||
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test"));
|
||||
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));
|
||||
|
||||
final int numLoops = 20;
|
||||
for (int i = 0; i < numLoops; i++)
|
||||
|
@ -194,7 +200,7 @@ public class TextStreamTest
|
|||
public void testFragmentedMessageOrdering() throws Exception
|
||||
{
|
||||
ClientTextStreamer client = new ClientTextStreamer();
|
||||
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test"));
|
||||
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));
|
||||
|
||||
final int numLoops = 20;
|
||||
for (int i = 0; i < numLoops; i++)
|
||||
|
@ -218,52 +224,68 @@ public class TextStreamTest
|
|||
@Test
|
||||
public void testMessageOrderingDoNotReadToEOF() throws Exception
|
||||
{
|
||||
ClientTextStreamer client = new ClientTextStreamer();
|
||||
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/partial"));
|
||||
ClientTextStreamer clientEndpoint = new ClientTextStreamer();
|
||||
Session session = wsClient.connectToServer(clientEndpoint, clientConfig, server.getWsUri().resolve("/partial"));
|
||||
QueuedTextStreamer serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
|
||||
|
||||
int serverInputBufferSize = 1024;
|
||||
JavaxWebSocketSession serverSession = (JavaxWebSocketSession)serverEndpoint.session;
|
||||
serverSession.getCoreSession().setInputBufferSize(serverInputBufferSize);
|
||||
|
||||
// Write some initial data.
|
||||
Writer writer = session.getBasicRemote().getSendWriter();
|
||||
writer.write("first frame");
|
||||
writer.flush();
|
||||
|
||||
// Signal to stop reading.
|
||||
writer.write("|");
|
||||
writer.flush();
|
||||
|
||||
// Lots of data after we have stopped reading and onMessage exits.
|
||||
final String largePayload = StringUtil.stringFrom("x", serverInputBufferSize * 2);
|
||||
writer.write(largePayload);
|
||||
writer.close();
|
||||
|
||||
final int numLoops = 20;
|
||||
for (int i = 0; i < numLoops; i++)
|
||||
{
|
||||
session.getBasicRemote().sendText(i + "|-----");
|
||||
}
|
||||
session.close();
|
||||
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
|
||||
assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
|
||||
assertNull(clientEndpoint.error.get());
|
||||
assertNull(serverEndpoint.error.get());
|
||||
|
||||
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
|
||||
assertNotNull(queuedTextStreamer);
|
||||
for (int i = 0; i < numLoops; i++)
|
||||
{
|
||||
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
|
||||
assertThat(msg, Matchers.is(Integer.toString(i)));
|
||||
}
|
||||
String msg = serverEndpoint.messages.poll(5, TimeUnit.SECONDS);
|
||||
assertThat(msg, Matchers.is("first frame"));
|
||||
}
|
||||
|
||||
@ClientEndpoint
|
||||
public static class ClientTextStreamer
|
||||
public static class ClientTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
|
||||
{
|
||||
private final CountDownLatch latch = new CountDownLatch(1);
|
||||
private final StringBuilder output = new StringBuilder();
|
||||
|
||||
@OnMessage
|
||||
public void echoed(Reader input) throws IOException
|
||||
@Override
|
||||
public void onOpen(Session session, EndpointConfig config)
|
||||
{
|
||||
while (true)
|
||||
session.addMessageHandler(this);
|
||||
super.onOpen(session, config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(Reader input)
|
||||
{
|
||||
try
|
||||
{
|
||||
int read = input.read();
|
||||
if (read < 0)
|
||||
break;
|
||||
output.append((char)read);
|
||||
while (true)
|
||||
{
|
||||
int read = input.read();
|
||||
if (read < 0)
|
||||
break;
|
||||
output.append((char)read);
|
||||
}
|
||||
latch.countDown();
|
||||
}
|
||||
catch (IOException e)
|
||||
{
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
public char[] getEcho()
|
||||
{
|
||||
return output.toString().toCharArray();
|
||||
}
|
||||
|
||||
public boolean await(long timeout, TimeUnit unit) throws InterruptedException
|
||||
{
|
||||
return latch.await(timeout, unit);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -289,19 +311,16 @@ public class TextStreamTest
|
|||
}
|
||||
}
|
||||
|
||||
public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole<Reader>
|
||||
public static class QueuedTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
|
||||
{
|
||||
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();
|
||||
|
||||
public QueuedTextStreamer()
|
||||
{
|
||||
serverEndpoints.add(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(Session session, EndpointConfig config)
|
||||
{
|
||||
session.addMessageHandler(this);
|
||||
super.onOpen(session, config);
|
||||
serverEndpoints.add(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
package org.eclipse.jetty.websocket.util.messages;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.lang.invoke.MethodHandle;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
|
@ -120,6 +121,10 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
|
|||
{
|
||||
methodHandle.invoke(typeSink);
|
||||
dispatchComplete.complete(null);
|
||||
|
||||
// If the MessageSink can be closed do this to free up resources.
|
||||
if (typeSink instanceof Closeable)
|
||||
((Closeable)typeSink).close();
|
||||
}
|
||||
catch (Throwable throwable)
|
||||
{
|
||||
|
|
|
@ -102,14 +102,14 @@ public class MessageInputStream extends InputStream implements MessageSink
|
|||
|
||||
public int read(ByteBuffer buffer) throws IOException
|
||||
{
|
||||
Entry result = getCurrentEntry();
|
||||
Entry currentEntry = getCurrentEntry();
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("result = {}", result);
|
||||
LOG.debug("currentEntry = {}", currentEntry);
|
||||
|
||||
if (result == CLOSED)
|
||||
if (currentEntry == CLOSED)
|
||||
throw new IOException("Closed");
|
||||
|
||||
if (result == EOF)
|
||||
if (currentEntry == EOF)
|
||||
{
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Read EOF");
|
||||
|
@ -117,11 +117,13 @@ public class MessageInputStream extends InputStream implements MessageSink
|
|||
}
|
||||
|
||||
// We have content.
|
||||
int fillLen = BufferUtil.append(buffer, result.buffer);
|
||||
if (!result.buffer.hasRemaining())
|
||||
int fillLen = BufferUtil.append(buffer, currentEntry.buffer);
|
||||
if (!currentEntry.buffer.hasRemaining())
|
||||
succeedCurrentEntry();
|
||||
|
||||
// Return number of bytes actually copied into buffer.
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("filled {} bytes from {}", fillLen, currentEntry);
|
||||
return fillLen;
|
||||
}
|
||||
|
||||
|
@ -131,7 +133,7 @@ public class MessageInputStream extends InputStream implements MessageSink
|
|||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("close()");
|
||||
|
||||
ArrayList<Entry> failedEntries = new ArrayList<>();
|
||||
ArrayList<Entry> entries = new ArrayList<>();
|
||||
synchronized (this)
|
||||
{
|
||||
if (closed)
|
||||
|
@ -140,20 +142,20 @@ public class MessageInputStream extends InputStream implements MessageSink
|
|||
|
||||
if (currentEntry != null)
|
||||
{
|
||||
failedEntries.add(currentEntry);
|
||||
entries.add(currentEntry);
|
||||
currentEntry = null;
|
||||
}
|
||||
|
||||
// Clear queue and fail all entries.
|
||||
failedEntries.addAll(buffers);
|
||||
entries.addAll(buffers);
|
||||
buffers.clear();
|
||||
buffers.offer(CLOSED);
|
||||
}
|
||||
|
||||
Throwable cause = new IOException("Closed");
|
||||
for (Entry e : failedEntries)
|
||||
// Succeed all entries as we don't need them anymore (failing would close the connection).
|
||||
for (Entry e : entries)
|
||||
{
|
||||
e.callback.failed(cause);
|
||||
e.callback.succeeded();
|
||||
}
|
||||
|
||||
super.close();
|
||||
|
|
Loading…
Reference in New Issue