Issue #4538 - simplify MessageInputStream and DispatchedMessageSink

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-02-14 19:02:02 +11:00
parent 5fe202f29f
commit e2f86f9a19
7 changed files with 322 additions and 220 deletions

View File

@ -20,16 +20,29 @@ package org.eclipse.jetty.websocket.javax.tests.server;
import java.io.IOException;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CloseStatus;
@ -38,29 +51,39 @@ import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.javax.tests.DataUtils;
import org.eclipse.jetty.websocket.javax.tests.Fuzzer;
import org.eclipse.jetty.websocket.javax.tests.LocalServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
public class TextStreamTest
{
private static final Logger LOG = Log.getLogger(TextStreamTest.class);
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();
private static LocalServer server;
private static ServerContainer container;
private LocalServer server;
private ServerContainer container;
private WebSocketContainer wsClient;
@BeforeAll
public static void startServer() throws Exception
@BeforeEach
public void startServer() throws Exception
{
server = new LocalServer();
server.start();
container = server.getServerContainer();
container.addEndpoint(ServerTextStreamer.class);
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedTextStreamer.class, "/test").build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedPartialTextStreamer.class, "/partial").build());
wsClient = ContainerProvider.getWebSocketContainer();
}
@AfterAll
public static void stopServer() throws Exception
@AfterEach
public void stopServer() throws Exception
{
server.stop();
}
@ -145,6 +168,105 @@ public class TextStreamTest
}
}
@Test
public void testMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(Integer.toString(i));
}
session.close();
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}
@Test
public void testFragmentedMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText("firstFrame" + i, false);
session.getBasicRemote().sendText("|secondFrame" + i, false);
session.getBasicRemote().sendText("|finalFrame" + i, true);
}
session.close();
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
String expected = "firstFrame" + i + "|secondFrame" + i + "|finalFrame" + i;
assertThat(msg, Matchers.is(expected));
}
}
@Test
public void testMessageOrderingDoNotReadToEOF() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/partial"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(i + "|-----");
}
session.close();
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}
@ClientEndpoint
public static class ClientTextStreamer
{
private final CountDownLatch latch = new CountDownLatch(1);
private final StringBuilder output = new StringBuilder();
@OnMessage
public void echoed(Reader input) throws IOException
{
while (true)
{
int read = input.read();
if (read < 0)
break;
output.append((char)read);
}
latch.countDown();
}
public char[] getEcho()
{
return output.toString().toCharArray();
}
public boolean await(long timeout, TimeUnit unit) throws InterruptedException
{
return latch.await(timeout, unit);
}
}
@ServerEndpoint("/echo")
public static class ServerTextStreamer
{
@ -166,4 +288,62 @@ public class TextStreamTest
}
}
}
public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole<Reader>
{
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();
public QueuedTextStreamer()
{
serverEndpoints.add(this);
}
@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
}
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
messages.add(IO.toString(input));
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
public static class QueuedPartialTextStreamer extends QueuedTextStreamer
{
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
// Do not read to EOF but just the first '|'.
StringWriter writer = new StringWriter();
while (true)
{
int read = input.read();
if (read < 0 || read == '|')
break;
writer.write(read);
}
messages.add(writer.toString());
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
}

View File

@ -1,44 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util.messages;
import java.nio.ByteBuffer;
import java.util.Objects;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
public class CallbackBuffer
{
public ByteBuffer buffer;
public Callback callback;
public CallbackBuffer(Callback callback, ByteBuffer buffer)
{
Objects.requireNonNull(buffer, "buffer");
this.callback = callback;
this.buffer = buffer;
}
@Override
public String toString()
{
return String.format("CallbackBuffer[%s,%s]", BufferUtil.toDetailString(buffer), callback.getClass().getSimpleName());
}
}

View File

@ -22,6 +22,7 @@ import java.lang.invoke.MethodHandle;
import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
@ -93,11 +94,8 @@ import org.eclipse.jetty.websocket.core.Frame;
* EOF stream.read EOF
* RESUME(NEXT MSG)
* </pre>
*
* @param <T> the type of object to give to user function
*/
@SuppressWarnings("Duplicates")
public abstract class DispatchedMessageSink<T> extends AbstractMessageSink
public abstract class DispatchedMessageSink extends AbstractMessageSink
{
private CompletableFuture<Void> dispatchComplete;
private MessageSink typeSink;
@ -114,14 +112,14 @@ public abstract class DispatchedMessageSink<T> extends AbstractMessageSink
if (typeSink == null)
{
typeSink = newSink(frame);
// Dispatch to end user function (will likely start with blocking for data/accept)
dispatchComplete = new CompletableFuture<>();
// Dispatch to end user function (will likely start with blocking for data/accept)
new Thread(() ->
{
final T dispatchedType = (T)typeSink;
try
{
methodHandle.invoke(dispatchedType);
methodHandle.invoke(typeSink);
dispatchComplete.complete(null);
}
catch (Throwable throwable)
@ -131,40 +129,21 @@ public abstract class DispatchedMessageSink<T> extends AbstractMessageSink
}).start();
}
final Callback frameCallback;
Callback frameCallback = callback;
if (frame.isFin())
{
CompletableFuture<Void> finComplete = new CompletableFuture<>();
frameCallback = new Callback()
// This is the final frame we should wait for the frame callback and the dispatched thread.
Callback.Completable completableCallback = new Callback.Completable();
frameCallback = completableCallback;
CompletableFuture.allOf(dispatchComplete, completableCallback).whenComplete((aVoid, throwable) ->
{
@Override
public void failed(Throwable cause)
{
finComplete.completeExceptionally(cause);
}
@Override
public void succeeded()
{
finComplete.complete(null);
}
};
CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete(
(aVoid, throwable) ->
{
typeSink = null;
dispatchComplete = null;
if (throwable != null)
callback.failed(throwable);
else
callback.succeeded();
});
}
else
{
// Non-fin-frame
frameCallback = callback;
typeSink = null;
dispatchComplete = null;
if (throwable != null)
callback.failed(throwable);
else
callback.succeeded();
});
}
typeSink.accept(frame, frameCallback);

View File

@ -18,13 +18,12 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
public class InputStreamMessageSink extends DispatchedMessageSink<InputStream>
public class InputStreamMessageSink extends DispatchedMessageSink
{
public InputStreamMessageSink(CoreSession session, MethodHandle methodHandle)
{

View File

@ -21,10 +21,12 @@ package org.eclipse.jetty.websocket.util.messages;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.log.Log;
@ -40,10 +42,11 @@ import org.eclipse.jetty.websocket.core.Frame;
public class MessageInputStream extends InputStream implements MessageSink
{
private static final Logger LOG = Log.getLogger(MessageInputStream.class);
private static final CallbackBuffer EOF = new CallbackBuffer(Callback.NOOP, BufferUtil.EMPTY_BUFFER);
private final Deque<CallbackBuffer> buffers = new ArrayDeque<>(2);
private static final Entry EOF = new Entry(BufferUtil.EMPTY_BUFFER, Callback.NOOP);
private final BlockingArrayQueue<Entry> buffers = new BlockingArrayQueue<>();
private final AtomicBoolean closed = new AtomicBoolean(false);
private CallbackBuffer activeFrame;
private Entry currentEntry;
private long timeoutMs = -1;
@Override
public void accept(Frame frame, Callback callback)
@ -51,119 +54,20 @@ public class MessageInputStream extends InputStream implements MessageSink
if (LOG.isDebugEnabled())
LOG.debug("accepting {}", frame);
// If closed, we should just toss incoming payloads into the bit bucket.
if (closed.get())
{
callback.failed(new IOException("Already Closed"));
return;
}
if (!frame.hasPayload() && !frame.isFin())
// If closed or we have no payload, request the next frame.
if (closed.get() || (!frame.hasPayload() && !frame.isFin()))
{
callback.succeeded();
return;
}
synchronized (buffers)
{
boolean notify = false;
if (frame.hasPayload())
{
buffers.offer(new CallbackBuffer(callback, frame.getPayload()));
notify = true;
}
else
{
// We cannot wake up blocking read for a zero length frame.
callback.succeeded();
}
if (frame.hasPayload())
buffers.add(new Entry(frame.getPayload(), callback));
else
callback.succeeded();
if (frame.isFin())
{
buffers.offer(EOF);
notify = true;
}
if (notify)
{
// notify other thread
buffers.notify();
}
}
}
@Override
public void close() throws IOException
{
if (LOG.isDebugEnabled())
LOG.debug("close()");
if (closed.compareAndSet(false, true))
{
synchronized (buffers)
{
buffers.offer(EOF);
buffers.notify();
}
}
super.close();
}
public CallbackBuffer getActiveFrame() throws InterruptedIOException
{
if (activeFrame == null)
{
// sync and poll queue
CallbackBuffer result;
synchronized (buffers)
{
try
{
while ((result = buffers.poll()) == null)
{
// TODO: handle read timeout here?
buffers.wait();
}
}
catch (InterruptedException e)
{
shutdown();
throw new InterruptedIOException();
}
}
activeFrame = result;
}
return activeFrame;
}
private void shutdown()
{
if (LOG.isDebugEnabled())
LOG.debug("shutdown()");
synchronized (buffers)
{
closed.set(true);
Throwable cause = new IOException("Shutdown");
for (CallbackBuffer buffer : buffers)
{
buffer.callback.failed(cause);
}
// Removed buffers that may have remained in the queue.
buffers.clear();
}
}
@Override
public void mark(int readlimit)
{
// Not supported.
}
@Override
public boolean markSupported()
{
return false;
if (frame.isFin())
buffers.add(EOF);
}
@Override
@ -185,14 +89,9 @@ public class MessageInputStream extends InputStream implements MessageSink
public int read(final byte[] b, final int off, final int len) throws IOException
{
if (closed.get())
{
if (LOG.isDebugEnabled())
LOG.debug("Stream closed");
return -1;
}
CallbackBuffer result = getActiveFrame();
Entry result = getCurrentEntry();
if (LOG.isDebugEnabled())
LOG.debug("result = {}", result);
@ -207,10 +106,9 @@ public class MessageInputStream extends InputStream implements MessageSink
// We have content
int fillLen = Math.min(result.buffer.remaining(), len);
result.buffer.get(b, off, fillLen);
if (!result.buffer.hasRemaining())
{
activeFrame = null;
currentEntry = null;
result.callback.succeeded();
}
@ -219,8 +117,94 @@ public class MessageInputStream extends InputStream implements MessageSink
}
@Override
public void reset() throws IOException
public void close() throws IOException
{
throw new IOException("reset() not supported");
if (LOG.isDebugEnabled())
LOG.debug("close()");
if (closed.compareAndSet(false, true))
{
synchronized (buffers)
{
buffers.offer(EOF);
buffers.notify();
}
}
super.close();
}
private void shutdown()
{
if (LOG.isDebugEnabled())
LOG.debug("shutdown()");
synchronized (this)
{
closed.set(true);
Throwable cause = new IOException("Shutdown");
for (Entry buffer : buffers)
{
buffer.callback.failed(cause);
}
// Removed buffers that may have remained in the queue.
buffers.clear();
}
}
public void setTimeout(long timeoutMs)
{
this.timeoutMs = timeoutMs;
}
private Entry getCurrentEntry() throws IOException
{
if (currentEntry != null)
return currentEntry;
// sync and poll queue
try
{
if (LOG.isDebugEnabled())
LOG.debug("Waiting {} ms to read", timeoutMs);
if (timeoutMs < 0)
{
// Wait forever until a buffer is available.
currentEntry = buffers.take();
}
else
{
// Wait at most for the given timeout.
currentEntry = buffers.poll(timeoutMs, TimeUnit.MILLISECONDS);
if (currentEntry == null)
throw new IOException(String.format("Read timeout: %,dms expired", timeoutMs));
}
}
catch (InterruptedException e)
{
shutdown();
throw new InterruptedIOException();
}
return currentEntry;
}
private static class Entry
{
public ByteBuffer buffer;
public Callback callback;
public Entry(ByteBuffer buffer, Callback callback)
{
this.buffer = Objects.requireNonNull(buffer);
this.callback = callback;
}
@Override
public String toString()
{
return String.format("Entry[%s,%s]", BufferUtil.toDetailString(buffer), callback.getClass().getSimpleName());
}
}
}

View File

@ -33,10 +33,15 @@ public class MessageReader extends InputStreamReader implements MessageSink
{
private final MessageInputStream stream;
public MessageReader(MessageInputStream stream)
public MessageReader()
{
super(stream, StandardCharsets.UTF_8);
this.stream = stream;
this(new MessageInputStream());
}
private MessageReader(MessageInputStream inputStream)
{
super(inputStream, StandardCharsets.UTF_8);
this.stream = inputStream;
}
@Override

View File

@ -18,13 +18,12 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.Reader;
import java.lang.invoke.MethodHandle;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
public class ReaderMessageSink extends DispatchedMessageSink<Reader>
public class ReaderMessageSink extends DispatchedMessageSink
{
public ReaderMessageSink(CoreSession session, MethodHandle methodHandle)
{
@ -34,6 +33,6 @@ public class ReaderMessageSink extends DispatchedMessageSink<Reader>
@Override
public MessageReader newSink(Frame frame)
{
return new MessageReader(new MessageInputStream());
return new MessageReader();
}
}