Merge pull request #4588 from eclipse/jetty-10.0.x-4538-MessageReaderWriter

Issue #4538 - rework of websocket message reader and writers
This commit is contained in:
Lachlan 2020-03-11 15:47:12 +11:00 committed by GitHub
commit b1d30fcd6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 713 additions and 519 deletions

View File

@ -438,6 +438,22 @@ public class StringUtil
}
}
/**
* Generate a string from another string repeated n times.
*
* @param s the string to use
* @param n the number of times this string should be appended
*/
public static String stringFrom(String s, int n)
{
StringBuilder stringBuilder = new StringBuilder(s.length() * n);
for (int i = 0; i < n; i++)
{
stringBuilder.append(s);
}
return stringBuilder.toString();
}
/**
* Return a non null string.
*

View File

@ -21,7 +21,6 @@ package org.eclipse.jetty.websocket.javax.tests.client;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
@ -37,15 +36,13 @@ import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.MessageHandler;
import org.eclipse.jetty.websocket.core.server.Negotiation;
import org.eclipse.jetty.websocket.core.server.WebSocketNegotiator;
import org.eclipse.jetty.websocket.javax.tests.CoreServer;
import org.eclipse.jetty.websocket.javax.tests.WSEventTracker;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ -58,21 +55,14 @@ public class DecoderReaderManySmallTest
@BeforeEach
public void setUp() throws Exception
{
server = new CoreServer(new CoreServer.BaseNegotiator()
server = new CoreServer(WebSocketNegotiator.from((negotiation) ->
{
@Override
public FrameHandler negotiate(Negotiation negotiation) throws IOException
{
List<String> offeredSubProtocols = negotiation.getOfferedSubprotocols();
List<String> offeredSubProtocols = negotiation.getOfferedSubprotocols();
if (!offeredSubProtocols.isEmpty())
negotiation.setSubprotocol(offeredSubProtocols.get(0));
if (!offeredSubProtocols.isEmpty())
{
negotiation.setSubprotocol(offeredSubProtocols.get(0));
}
return new EventIdFrameHandler();
}
});
return new EventIdFrameHandler();
}));
server.start();
client = ContainerProvider.getWebSocketContainer();
@ -86,15 +76,13 @@ public class DecoderReaderManySmallTest
}
@Test
public void testManyIds(TestInfo testInfo) throws Exception
public void testManyIds() throws Exception
{
URI wsUri = server.getWsUri().resolve("/eventids");
EventIdSocket clientSocket = new EventIdSocket(testInfo.getTestMethod().toString());
final int from = 1000;
final int to = 2000;
try (Session clientSession = client.connectToServer(clientSocket, wsUri))
EventIdSocket clientSocket = new EventIdSocket();
try (Session clientSession = client.connectToServer(clientSocket, server.getWsUri()))
{
clientSession.getAsyncRemote().sendText("seq|" + from + "|" + to);
}
@ -154,12 +142,6 @@ public class DecoderReaderManySmallTest
{
public BlockingQueue<EventId> messageQueue = new LinkedBlockingDeque<>();
public EventIdSocket(String id)
{
super(id);
}
@SuppressWarnings("unused")
@OnMessage
public void onMessage(EventId msg)
{

View File

@ -20,47 +20,76 @@ package org.eclipse.jetty.websocket.javax.tests.server;
import java.io.IOException;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession;
import org.eclipse.jetty.websocket.javax.tests.DataUtils;
import org.eclipse.jetty.websocket.javax.tests.Fuzzer;
import org.eclipse.jetty.websocket.javax.tests.LocalServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TextStreamTest
{
private static final Logger LOG = Log.getLogger(TextStreamTest.class);
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();
private static LocalServer server;
private static ServerContainer container;
private final ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().build();
private LocalServer server;
private ServerContainer container;
private WebSocketContainer wsClient;
@BeforeAll
public static void startServer() throws Exception
@BeforeEach
public void startServer() throws Exception
{
server = new LocalServer();
server.start();
container = server.getServerContainer();
container.addEndpoint(ServerTextStreamer.class);
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedTextStreamer.class, "/test").build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedPartialTextStreamer.class, "/partial").build());
wsClient = ContainerProvider.getWebSocketContainer();
}
@AfterAll
public static void stopServer() throws Exception
@AfterEach
public void stopServer() throws Exception
{
server.stop();
}
@ -145,6 +174,121 @@ public class TextStreamTest
}
}
@Test
public void testMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(Integer.toString(i));
}
session.close();
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}
@Test
public void testFragmentedMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));
final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText("firstFrame" + i, false);
session.getBasicRemote().sendText("|secondFrame" + i, false);
session.getBasicRemote().sendText("|finalFrame" + i, true);
}
session.close();
QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
String expected = "firstFrame" + i + "|secondFrame" + i + "|finalFrame" + i;
assertThat(msg, Matchers.is(expected));
}
}
@Test
public void testMessageOrderingDoNotReadToEOF() throws Exception
{
ClientTextStreamer clientEndpoint = new ClientTextStreamer();
Session session = wsClient.connectToServer(clientEndpoint, clientConfig, server.getWsUri().resolve("/partial"));
QueuedTextStreamer serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
int serverInputBufferSize = 1024;
JavaxWebSocketSession serverSession = (JavaxWebSocketSession)serverEndpoint.session;
serverSession.getCoreSession().setInputBufferSize(serverInputBufferSize);
// Write some initial data.
Writer writer = session.getBasicRemote().getSendWriter();
writer.write("first frame");
writer.flush();
// Signal to stop reading.
writer.write("|");
writer.flush();
// Lots of data after we have stopped reading and onMessage exits.
final String largePayload = StringUtil.stringFrom("x", serverInputBufferSize * 2);
writer.write(largePayload);
writer.close();
session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertNull(clientEndpoint.error.get());
assertNull(serverEndpoint.error.get());
String msg = serverEndpoint.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is("first frame"));
}
public static class ClientTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
{
private final CountDownLatch latch = new CountDownLatch(1);
private final StringBuilder output = new StringBuilder();
@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
super.onOpen(session, config);
}
@Override
public void onMessage(Reader input)
{
try
{
while (true)
{
int read = input.read();
if (read < 0)
break;
output.append((char)read);
}
latch.countDown();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
}
@ServerEndpoint("/echo")
public static class ServerTextStreamer
{
@ -166,4 +310,59 @@ public class TextStreamTest
}
}
}
public static class QueuedTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
{
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();
@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
super.onOpen(session, config);
serverEndpoints.add(this);
}
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
messages.add(IO.toString(input));
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
public static class QueuedPartialTextStreamer extends QueuedTextStreamer
{
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
// Do not read to EOF but just the first '|'.
StringWriter writer = new StringWriter();
while (true)
{
int read = input.read();
if (read < 0 || read == '|')
break;
writer.write(read);
}
messages.add(writer.toString());
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
}

View File

@ -26,7 +26,9 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
@ -36,6 +38,7 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeout;
public class MessageInputStreamTest
@ -166,7 +169,7 @@ public class MessageInputStreamTest
{
// wait for a little bit before sending input closed
TimeUnit.MILLISECONDS.sleep(400);
stream.close();
stream.accept(new Frame(OpCode.TEXT, true, BufferUtil.EMPTY_BUFFER), Callback.NOOP);
}
catch (Throwable t)
{
@ -177,11 +180,22 @@ public class MessageInputStreamTest
// Read byte from stream.
int b = stream.read();
// Should be a -1, indicating the end of the stream.
// Test it
// Should be a -1, indicating the end of the stream.
assertThat("Error when closing", hadError.get(), is(false));
assertThat("Initial byte (Should be EOF)", b, is(-1));
// Close the stream.
stream.close();
// Any frame content after stream is closed should be discarded, and the callback succeeded.
FutureCallback callback = new FutureCallback();
stream.accept(new Frame(OpCode.TEXT, true, BufferUtil.toBuffer("hello world")), callback);
callback.block(5, TimeUnit.SECONDS);
// Any read after the stream is closed leads to an IOException.
IOException error = assertThrows(IOException.class, stream::read);
assertThat(error.getMessage(), is("Closed"));
}
});
}

View File

@ -1,103 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.common;
import java.util.Arrays;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.util.messages.MessageWriter;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
public class MessageWriterTest
{
private static final Logger LOG = Log.getLogger(MessageWriterTest.class);
private static final int OUTPUT_BUFFER_SIZE = 4096;
public TestableLeakTrackingBufferPool bufferPool = new TestableLeakTrackingBufferPool("Test");
@AfterEach
public void afterEach()
{
bufferPool.assertNoLeaks();
}
private OutgoingMessageCapture remoteSocket;
@BeforeEach
public void setupSession()
{
remoteSocket = new OutgoingMessageCapture();
remoteSocket.setOutputBufferSize(OUTPUT_BUFFER_SIZE);
}
@Test
public void testMultipleWrites() throws Exception
{
try (MessageWriter stream = new MessageWriter(remoteSocket, bufferPool))
{
stream.write("Hello");
stream.write(" ");
stream.write("World");
}
assertThat("Socket.messageQueue.size", remoteSocket.textMessages.size(), is(1));
String msg = remoteSocket.textMessages.poll();
assertThat("Message", msg, is("Hello World"));
}
@Test
public void testSingleWrite() throws Exception
{
try (MessageWriter stream = new MessageWriter(remoteSocket, bufferPool))
{
stream.append("Hello World");
}
assertThat("Socket.messageQueue.size", remoteSocket.textMessages.size(), is(1));
String msg = remoteSocket.textMessages.poll();
assertThat("Message", msg, is("Hello World"));
}
@Test
public void testWriteLargeRequiringMultipleBuffers() throws Exception
{
int size = (int)(OUTPUT_BUFFER_SIZE * 2.5);
char[] buf = new char[size];
if (LOG.isDebugEnabled())
LOG.debug("Buffer size: {}", size);
Arrays.fill(buf, 'x');
buf[size - 1] = 'o'; // mark last entry for debugging
try (MessageWriter stream = new MessageWriter(remoteSocket, bufferPool))
{
stream.write(buf);
}
assertThat("Socket.messageQueue.size", remoteSocket.textMessages.size(), is(1));
String msg = remoteSocket.textMessages.poll();
String expected = new String(buf);
assertThat("Message", msg, is(expected));
}
}

View File

@ -121,7 +121,7 @@ public class OutgoingMessageCapture extends CoreSession.Empty implements CoreSes
if (OpCode.isDataFrame(frame.getOpCode()))
{
messageSink.accept(frame, callback);
messageSink.accept(Frame.copy(frame), callback);
if (frame.isFin())
{
messageSink = null;

View File

@ -1,44 +0,0 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util.messages;
import java.nio.ByteBuffer;
import java.util.Objects;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
public class CallbackBuffer
{
public ByteBuffer buffer;
public Callback callback;
public CallbackBuffer(Callback callback, ByteBuffer buffer)
{
Objects.requireNonNull(buffer, "buffer");
this.callback = callback;
this.buffer = buffer;
}
@Override
public String toString()
{
return String.format("CallbackBuffer[%s,%s]", BufferUtil.toDetailString(buffer), callback.getClass().getSimpleName());
}
}

View File

@ -18,10 +18,12 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.Closeable;
import java.lang.invoke.MethodHandle;
import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
@ -93,11 +95,8 @@ import org.eclipse.jetty.websocket.core.Frame;
* EOF stream.read EOF
* RESUME(NEXT MSG)
* </pre>
*
* @param <T> the type of object to give to user function
*/
@SuppressWarnings("Duplicates")
public abstract class DispatchedMessageSink<T> extends AbstractMessageSink
public abstract class DispatchedMessageSink extends AbstractMessageSink
{
private CompletableFuture<Void> dispatchComplete;
private MessageSink typeSink;
@ -114,44 +113,45 @@ public abstract class DispatchedMessageSink<T> extends AbstractMessageSink
if (typeSink == null)
{
typeSink = newSink(frame);
// Dispatch to end user function (will likely start with blocking for data/accept)
dispatchComplete = new CompletableFuture<>();
// Dispatch to end user function (will likely start with blocking for data/accept).
// If the MessageSink can be closed do this after invoking and before completing the CompletableFuture.
new Thread(() ->
{
final T dispatchedType = (T)typeSink;
try
{
methodHandle.invoke(dispatchedType);
methodHandle.invoke(typeSink);
if (typeSink instanceof Closeable)
IO.close((Closeable)typeSink);
dispatchComplete.complete(null);
}
catch (Throwable throwable)
{
if (typeSink instanceof Closeable)
IO.close((Closeable)typeSink);
dispatchComplete.completeExceptionally(throwable);
}
}).start();
}
final Callback frameCallback;
Callback frameCallback = callback;
if (frame.isFin())
{
CompletableFuture<Void> finComplete = new CompletableFuture<>();
frameCallback = Callback.from(() -> finComplete.complete(null), finComplete::completeExceptionally);
CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete(
(aVoid, throwable) ->
{
typeSink = null;
dispatchComplete = null;
if (throwable != null)
callback.failed(throwable);
else
callback.succeeded();
});
}
else
{
// Non-fin-frame
frameCallback = callback;
// This is the final frame we should wait for the frame callback and the dispatched thread.
Callback.Completable completableCallback = new Callback.Completable();
frameCallback = completableCallback;
CompletableFuture.allOf(dispatchComplete, completableCallback).whenComplete((aVoid, throwable) ->
{
typeSink = null;
dispatchComplete = null;
if (throwable != null)
callback.failed(throwable);
else
callback.succeeded();
});
}
typeSink.accept(frame, frameCallback);

View File

@ -18,13 +18,12 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
public class InputStreamMessageSink extends DispatchedMessageSink<InputStream>
public class InputStreamMessageSink extends DispatchedMessageSink
{
public InputStreamMessageSink(CoreSession session, MethodHandle methodHandle)
{

View File

@ -21,10 +21,12 @@ package org.eclipse.jetty.websocket.util.messages;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.log.Log;
@ -40,10 +42,12 @@ import org.eclipse.jetty.websocket.core.Frame;
public class MessageInputStream extends InputStream implements MessageSink
{
private static final Logger LOG = Log.getLogger(MessageInputStream.class);
private static final CallbackBuffer EOF = new CallbackBuffer(Callback.NOOP, BufferUtil.EMPTY_BUFFER);
private final Deque<CallbackBuffer> buffers = new ArrayDeque<>(2);
private final AtomicBoolean closed = new AtomicBoolean(false);
private CallbackBuffer activeFrame;
private static final Entry EOF = new Entry(BufferUtil.EMPTY_BUFFER, Callback.NOOP);
private static final Entry CLOSED = new Entry(BufferUtil.EMPTY_BUFFER, Callback.NOOP);
private final BlockingArrayQueue<Entry> buffers = new BlockingArrayQueue<>();
private boolean closed = false;
private Entry currentEntry;
private long timeoutMs = -1;
@Override
public void accept(Frame frame, Callback callback)
@ -51,119 +55,28 @@ public class MessageInputStream extends InputStream implements MessageSink
if (LOG.isDebugEnabled())
LOG.debug("accepting {}", frame);
// If closed, we should just toss incoming payloads into the bit bucket.
if (closed.get())
boolean succeed = false;
synchronized (this)
{
callback.failed(new IOException("Already Closed"));
return;
}
if (!frame.hasPayload() && !frame.isFin())
{
callback.succeeded();
return;
}
synchronized (buffers)
{
boolean notify = false;
if (frame.hasPayload())
// If closed or we have no payload, request the next frame.
if (closed || (!frame.hasPayload() && !frame.isFin()))
{
buffers.offer(new CallbackBuffer(callback, frame.getPayload()));
notify = true;
succeed = true;
}
else
{
// We cannot wake up blocking read for a zero length frame.
callback.succeeded();
}
if (frame.hasPayload())
buffers.add(new Entry(frame.getPayload(), callback));
else
succeed = true;
if (frame.isFin())
{
buffers.offer(EOF);
notify = true;
}
if (notify)
{
// notify other thread
buffers.notify();
if (frame.isFin())
buffers.add(EOF);
}
}
}
@Override
public void close() throws IOException
{
if (LOG.isDebugEnabled())
LOG.debug("close()");
if (closed.compareAndSet(false, true))
{
synchronized (buffers)
{
buffers.offer(EOF);
buffers.notify();
}
}
super.close();
}
public CallbackBuffer getActiveFrame() throws InterruptedIOException
{
if (activeFrame == null)
{
// sync and poll queue
CallbackBuffer result;
synchronized (buffers)
{
try
{
while ((result = buffers.poll()) == null)
{
// TODO: handle read timeout here?
buffers.wait();
}
}
catch (InterruptedException e)
{
shutdown();
throw new InterruptedIOException();
}
}
activeFrame = result;
}
return activeFrame;
}
private void shutdown()
{
if (LOG.isDebugEnabled())
LOG.debug("shutdown()");
synchronized (buffers)
{
closed.set(true);
Throwable cause = new IOException("Shutdown");
for (CallbackBuffer buffer : buffers)
{
buffer.callback.failed(cause);
}
// Removed buffers that may have remained in the queue.
buffers.clear();
}
}
@Override
public void mark(int readlimit)
{
// Not supported.
}
@Override
public boolean markSupported()
{
return false;
if (succeed)
callback.succeeded();
}
@Override
@ -184,43 +97,142 @@ public class MessageInputStream extends InputStream implements MessageSink
@Override
public int read(final byte[] b, final int off, final int len) throws IOException
{
if (closed.get())
{
if (LOG.isDebugEnabled())
LOG.debug("Stream closed");
return -1;
}
CallbackBuffer result = getActiveFrame();
return read(ByteBuffer.wrap(b, off, len).flip());
}
public int read(ByteBuffer buffer) throws IOException
{
Entry currentEntry = getCurrentEntry();
if (LOG.isDebugEnabled())
LOG.debug("result = {}", result);
LOG.debug("currentEntry = {}", currentEntry);
if (result == EOF)
if (currentEntry == CLOSED)
throw new IOException("Closed");
if (currentEntry == EOF)
{
if (LOG.isDebugEnabled())
LOG.debug("Read EOF");
shutdown();
return -1;
}
// We have content
int fillLen = Math.min(result.buffer.remaining(), len);
result.buffer.get(b, off, fillLen);
// We have content.
int fillLen = BufferUtil.append(buffer, currentEntry.buffer);
if (!currentEntry.buffer.hasRemaining())
succeedCurrentEntry();
if (!result.buffer.hasRemaining())
{
activeFrame = null;
result.callback.succeeded();
}
// return number of bytes actually copied into buffer
// Return number of bytes actually copied into buffer.
if (LOG.isDebugEnabled())
LOG.debug("filled {} bytes from {}", fillLen, currentEntry);
return fillLen;
}
@Override
public void reset() throws IOException
public void close() throws IOException
{
throw new IOException("reset() not supported");
if (LOG.isDebugEnabled())
LOG.debug("close()");
ArrayList<Entry> entries = new ArrayList<>();
synchronized (this)
{
if (closed)
return;
closed = true;
if (currentEntry != null)
{
entries.add(currentEntry);
currentEntry = null;
}
// Clear queue and fail all entries.
entries.addAll(buffers);
buffers.clear();
buffers.offer(CLOSED);
}
// Succeed all entries as we don't need them anymore (failing would close the connection).
for (Entry e : entries)
{
e.callback.succeeded();
}
super.close();
}
public void setTimeout(long timeoutMs)
{
this.timeoutMs = timeoutMs;
}
private void succeedCurrentEntry()
{
Entry current;
synchronized (this)
{
current = currentEntry;
currentEntry = null;
}
if (current != null)
current.callback.succeeded();
}
private Entry getCurrentEntry() throws IOException
{
synchronized (this)
{
if (currentEntry != null)
return currentEntry;
}
try
{
if (LOG.isDebugEnabled())
LOG.debug("Waiting {} ms to read", timeoutMs);
Entry result;
if (timeoutMs < 0)
{
// Wait forever until a buffer is available.
result = buffers.take();
}
else
{
// Wait at most for the given timeout.
result = buffers.poll(timeoutMs, TimeUnit.MILLISECONDS);
if (result == null)
throw new IOException(String.format("Read timeout: %,dms expired", timeoutMs));
}
synchronized (this)
{
currentEntry = result;
return currentEntry;
}
}
catch (InterruptedException e)
{
close();
throw new InterruptedIOException();
}
}
private static class Entry
{
public ByteBuffer buffer;
public Callback callback;
public Entry(ByteBuffer buffer, Callback callback)
{
this.buffer = Objects.requireNonNull(buffer);
this.callback = callback;
}
@Override
public String toString()
{
return String.format("Entry[%s,%s]", BufferUtil.toDetailString(buffer), callback.getClass().getSimpleName());
}
}
}

View File

@ -55,7 +55,6 @@ public class MessageOutputStream extends OutputStream
this.bufferPool = bufferPool;
this.bufferSize = coreSession.getOutputBufferSize();
this.buffer = bufferPool.acquire(bufferSize, true);
BufferUtil.clear(buffer);
}
void setMessageType(byte opcode)
@ -93,6 +92,20 @@ public class MessageOutputStream extends OutputStream
}
}
public void write(ByteBuffer buffer) throws IOException
{
try
{
send(buffer);
}
catch (Throwable x)
{
// Notify without holding locks.
notifyFailure(x);
throw x;
}
}
@Override
public void flush() throws IOException
{

View File

@ -18,30 +18,83 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.InputStreamReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CoderResult;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.WebSocketConstants;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
* Support class for reading a (single) WebSocket TEXT message via a Reader.
* <p>
* In compliance to the WebSocket spec, this reader always uses the {@link StandardCharsets#UTF_8}.
*/
public class MessageReader extends InputStreamReader implements MessageSink
public class MessageReader extends Reader implements MessageSink
{
private final MessageInputStream stream;
private static final int BUFFER_SIZE = WebSocketConstants.DEFAULT_INPUT_BUFFER_SIZE;
public MessageReader(MessageInputStream stream)
private final ByteBuffer buffer;
private final MessageInputStream stream;
private final CharsetDecoder utf8Decoder = UTF_8.newDecoder()
.onUnmappableCharacter(CodingErrorAction.REPORT)
.onMalformedInput(CodingErrorAction.REPORT);
public MessageReader()
{
super(stream, StandardCharsets.UTF_8);
this.stream = stream;
this(BUFFER_SIZE);
}
public MessageReader(int bufferSize)
{
this.stream = new MessageInputStream();
this.buffer = BufferUtil.allocate(bufferSize);
}
@Override
public int read(char[] cbuf, int off, int len) throws IOException
{
CharBuffer charBuffer = CharBuffer.wrap(cbuf, off, len);
boolean endOfInput = false;
while (true)
{
int read = stream.read(buffer);
if (read == 0)
break;
if (read < 0)
{
endOfInput = true;
break;
}
}
CoderResult result = utf8Decoder.decode(buffer, charBuffer, endOfInput);
if (result.isError())
result.throwException();
if (endOfInput && (charBuffer.position() == 0))
return -1;
return charBuffer.position();
}
@Override
public void close() throws IOException
{
stream.close();
}
@Override
public void accept(Frame frame, Callback callback)
{
this.stream.accept(frame, callback);
stream.accept(frame, callback);
}
}

View File

@ -20,19 +20,13 @@ package org.eclipse.jetty.websocket.util.messages;
import java.io.IOException;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CodingErrorAction;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import static java.nio.charset.StandardCharsets.UTF_8;
@ -44,180 +38,38 @@ import static java.nio.charset.StandardCharsets.UTF_8;
*/
public class MessageWriter extends Writer
{
private static final Logger LOG = Log.getLogger(MessageWriter.class);
private final MessageOutputStream outputStream;
private final CharsetEncoder utf8Encoder = UTF_8.newEncoder()
.onUnmappableCharacter(CodingErrorAction.REPORT)
.onMalformedInput(CodingErrorAction.REPORT);
private final CoreSession coreSession;
private long frameCount;
private Frame frame;
private CharBuffer buffer;
private Callback callback;
private boolean closed;
public MessageWriter(CoreSession coreSession, ByteBufferPool bufferPool)
{
this.coreSession = coreSession;
this.buffer = CharBuffer.allocate(coreSession.getOutputBufferSize());
this.frame = new Frame(OpCode.TEXT);
this.outputStream = new MessageOutputStream(coreSession, bufferPool);
this.outputStream.setMessageType(OpCode.TEXT);
}
@Override
public void write(char[] chars, int off, int len) throws IOException
public void write(char[] cbuf, int off, int len) throws IOException
{
try
{
send(chars, off, len);
}
catch (Throwable x)
{
// Notify without holding locks.
notifyFailure(x);
throw x;
}
}
@Override
public void write(int c) throws IOException
{
try
{
send(new char[]{(char)c}, 0, 1);
}
catch (Throwable x)
{
// Notify without holding locks.
notifyFailure(x);
throw x;
}
CharBuffer charBuffer = CharBuffer.wrap(cbuf, off, len);
outputStream.write(utf8Encoder.encode(charBuffer));
}
@Override
public void flush() throws IOException
{
try
{
flush(false);
}
catch (Throwable x)
{
// Notify without holding locks.
notifyFailure(x);
throw x;
}
}
private void flush(boolean fin) throws IOException
{
synchronized (this)
{
if (closed)
throw new IOException("Stream is closed");
closed = fin;
buffer.flip();
ByteBuffer payload = utf8Encoder.encode(buffer);
buffer.flip();
if (LOG.isDebugEnabled())
LOG.debug("flush({}): {}", fin, BufferUtil.toDetailString(payload));
frame.setPayload(payload);
frame.setFin(fin);
FutureCallback b = new FutureCallback();
coreSession.sendFrame(frame, b, false);
b.block();
++frameCount;
// Any flush after the first will be a CONTINUATION frame.
frame = new Frame(OpCode.CONTINUATION);
}
}
private void send(char[] chars, int offset, int length) throws IOException
{
synchronized (this)
{
if (closed)
throw new IOException("Stream is closed");
CharBuffer source = CharBuffer.wrap(chars, offset, length);
int remaining = length;
while (remaining > 0)
{
int read = source.read(buffer);
if (read == -1)
{
return;
}
remaining -= read;
if (remaining > 0)
{
// If we could not write everything, it means
// that the buffer was full, so flush it.
flush(false);
}
}
}
outputStream.flush();
}
@Override
public void close() throws IOException
{
try
{
flush(true);
if (LOG.isDebugEnabled())
LOG.debug("Stream closed, {} frames sent", frameCount);
// Notify without holding locks.
notifySuccess();
}
catch (Throwable x)
{
// Notify without holding locks.
notifyFailure(x);
throw x;
}
outputStream.close();
}
public void setCallback(Callback callback)
{
synchronized (this)
{
this.callback = callback;
}
}
private void notifySuccess()
{
Callback callback;
synchronized (this)
{
callback = this.callback;
}
if (callback != null)
{
callback.succeeded();
}
}
private void notifyFailure(Throwable failure)
{
Callback callback;
synchronized (this)
{
callback = this.callback;
}
if (callback != null)
{
callback.failed(failure);
}
outputStream.setCallback(callback);
}
}

View File

@ -18,13 +18,12 @@
package org.eclipse.jetty.websocket.util.messages;
import java.io.Reader;
import java.lang.invoke.MethodHandle;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
public class ReaderMessageSink extends DispatchedMessageSink<Reader>
public class ReaderMessageSink extends DispatchedMessageSink
{
public ReaderMessageSink(CoreSession session, MethodHandle methodHandle)
{
@ -34,6 +33,6 @@ public class ReaderMessageSink extends DispatchedMessageSink<Reader>
@Override
public MessageReader newSink(Frame frame)
{
return new MessageReader(new MessageInputStream());
return new MessageReader(session.getInputBufferSize());
}
}

View File

@ -0,0 +1,139 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.util;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.MalformedInputException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.FutureCallback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.util.messages.MessageReader;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageReaderTest
{
private final MessageReader reader = new MessageReader();
private final CompletableFuture<String> message = new CompletableFuture<>();
private boolean first = true;
@BeforeEach
public void before()
{
// Read the message in a different thread.
new Thread(() ->
{
try
{
message.complete(IO.toString(reader));
}
catch (IOException e)
{
message.completeExceptionally(e);
}
}).start();
}
@Test
public void testSingleFrameMessage() throws Exception
{
giveString("hello world!", true);
String s = message.get(5, TimeUnit.SECONDS);
assertThat(s, is("hello world!"));
}
@Test
public void testFragmentedMessage() throws Exception
{
giveString("hello", false);
giveString(" ", false);
giveString("world", false);
giveString("!", true);
String s = message.get(5, TimeUnit.SECONDS);
assertThat(s, is("hello world!"));
}
@Test
public void testEmptySegments() throws Exception
{
giveString("", false);
giveString("hello ", false);
giveString("", false);
giveString("", false);
giveString("world!", false);
giveString("", false);
giveString("", true);
String s = message.get(5, TimeUnit.SECONDS);
assertThat(s, is("hello world!"));
}
@Test
public void testCloseStream() throws Exception
{
giveString("hello ", false);
reader.close();
giveString("world!", true);
ExecutionException error = assertThrows(ExecutionException.class, () -> message.get(5, TimeUnit.SECONDS));
Throwable cause = error.getCause();
assertThat(cause, instanceOf(IOException.class));
assertThat(cause.getMessage(), is("Closed"));
}
@Test
public void testInvalidUtf8() throws Exception
{
ByteBuffer invalidUtf8Payload = BufferUtil.toBuffer(new byte[]{0x7F, (byte)0xFF, (byte)0xFF});
giveByteBuffer(invalidUtf8Payload, true);
ExecutionException error = assertThrows(ExecutionException.class, () -> message.get(5, TimeUnit.SECONDS));
assertThat(error.getCause(), instanceOf(MalformedInputException.class));
}
private void giveString(String s, boolean last) throws IOException
{
giveByteBuffer(ByteBuffer.wrap(StringUtil.getUtf8Bytes(s)), last);
}
private void giveByteBuffer(ByteBuffer buffer, boolean last) throws IOException
{
byte opCode = first ? OpCode.TEXT : OpCode.CONTINUATION;
Frame frame = new Frame(opCode, last, buffer);
FutureCallback callback = new FutureCallback();
reader.accept(frame, callback);
callback.block(5, TimeUnit.SECONDS);
first = false;
}
}

View File

@ -16,9 +16,10 @@
// ========================================================================
//
package org.eclipse.jetty.websocket.javax.common.messages;
package org.eclipse.jetty.websocket.util;
import java.io.IOException;
import java.nio.charset.MalformedInputException;
import java.util.Arrays;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
@ -36,10 +37,72 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageWriterTest
{
private ByteBufferPool bufferPool = new MappedByteBufferPool();
private final CoreSession coreSession = new CoreSession.Empty();
private final ByteBufferPool bufferPool = new MappedByteBufferPool();
@Test
public void testMultipleWrites() throws Exception
{
WholeMessageCapture capture = new WholeMessageCapture();
try (MessageWriter stream = new MessageWriter(capture, bufferPool))
{
stream.write("Hello");
stream.write(" ");
stream.write("World");
}
assertThat("Socket.messageQueue.size", capture.messages.size(), is(1));
String msg = capture.messages.poll();
assertThat("Message", msg, is("Hello World"));
}
@Test
public void testSingleWrite() throws Exception
{
WholeMessageCapture capture = new WholeMessageCapture();
try (MessageWriter stream = new MessageWriter(capture, bufferPool))
{
stream.append("Hello World");
}
assertThat("Socket.messageQueue.size", capture.messages.size(), is(1));
String msg = capture.messages.poll();
assertThat("Message", msg, is("Hello World"));
}
@Test
public void testWriteLargeRequiringMultipleBuffers() throws Exception
{
int outputBufferSize = 4096;
int size = (int)(outputBufferSize * 2.5);
char[] buf = new char[size];
Arrays.fill(buf, 'x');
buf[size - 1] = 'o'; // mark last entry for debugging
WholeMessageCapture capture = new WholeMessageCapture();
try (MessageWriter stream = new MessageWriter(capture, bufferPool))
{
stream.write(buf);
}
assertThat("Socket.messageQueue.size", capture.messages.size(), is(1));
String msg = capture.messages.poll();
String expected = new String(buf);
assertThat("Message", msg, is(expected));
}
@Test
public void testInvalidUtf8()
{
final String invalidUtf8String = "\uD800";
MessageWriter writer = new MessageWriter(coreSession, bufferPool);
assertThrows(MalformedInputException.class, () -> writer.write(invalidUtf8String.toCharArray()));
}
@Test
public void testSingleByteArray512b() throws IOException, InterruptedException