Merge pull request #5377 from eclipse/jetty-9.4.x-5368-WebSocketInputStream

Issue #5368 - ensure onMessage exits before next frame is read
This commit is contained in:
Lachlan 2020-10-16 15:51:54 +11:00 committed by GitHub
commit f99b4ca80c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 253 additions and 95 deletions

View File

@ -118,10 +118,8 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Binary Message InputStream"); LOG.debug("Binary Message InputStream");
final MessageInputStream stream = new MessageInputStream(session); MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream; activeMessage = stream;
// Always dispatch streaming read to another thread.
dispatch(() -> dispatch(() ->
{ {
try try
@ -133,7 +131,7 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
session.close(e); session.close(e);
} }
stream.close(); stream.handlerComplete();
}); });
} }
} }
@ -329,11 +327,8 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Text Message Writer"); LOG.debug("Text Message Writer");
MessageInputStream inputStream = new MessageInputStream(session); MessageReader reader = new MessageReader(session);
final MessageReader reader = new MessageReader(inputStream); activeMessage = reader;
activeMessage = inputStream;
// Always dispatch streaming read to another thread.
dispatch(() -> dispatch(() ->
{ {
try try
@ -343,9 +338,10 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
catch (Throwable e) catch (Throwable e)
{ {
session.close(e); session.close(e);
return;
} }
inputStream.close(); reader.handlerComplete();
}); });
} }
} }

View File

@ -100,9 +100,10 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
catch (Throwable t) catch (Throwable t)
{ {
session.close(t); session.close(t);
return;
} }
inputStream.close(); inputStream.handlerComplete();
}); });
} }
else else
@ -197,8 +198,7 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
{ {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler(); MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session); MessageReader reader = new MessageReader(session);
MessageReader reader = new MessageReader(inputStream);
activeMessage = reader; activeMessage = reader;
dispatch(() -> dispatch(() ->
{ {
@ -209,9 +209,10 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
catch (Throwable t) catch (Throwable t)
{ {
session.close(t); session.close(t);
return;
} }
inputStream.close(); reader.handlerComplete();
}); });
} }
else else

View File

@ -27,7 +27,9 @@ import java.util.Random;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint; import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider; import javax.websocket.ContainerProvider;
import javax.websocket.OnClose;
import javax.websocket.OnMessage; import javax.websocket.OnMessage;
import javax.websocket.Session; import javax.websocket.Session;
import javax.websocket.WebSocketContainer; import javax.websocket.WebSocketContainer;
@ -37,11 +39,15 @@ import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer; import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -127,6 +133,62 @@ public class BinaryStreamTest
assertArrayEquals(data, client.getEcho()); assertArrayEquals(data, client.getEcho());
} }
@Test
public void testNotReadingToEndOfStream() throws Exception
{
int size = 32;
byte[] data = randomBytes(size);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH);
CountDownLatch handlerComplete = new CountDownLatch(1);
BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) ->
{
byte[] recv = new byte[16];
int read = inputStream.read(recv);
assertThat(read, not(is(0)));
handlerComplete.countDown();
});
Session session = wsClient.connectToServer(client, uri);
session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data));
assertTrue(handlerComplete.await(5, TimeUnit.SECONDS));
session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test"));
assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
assertThat(client.closeReason.getReasonPhrase(), is("close from test"));
}
@Test
public void testClosingBeforeReadingToEndOfStream() throws Exception
{
int size = 32;
byte[] data = randomBytes(size);
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + PATH);
CountDownLatch handlerComplete = new CountDownLatch(1);
BasicClientBinaryStreamer client = new BasicClientBinaryStreamer((session, inputStream) ->
{
byte[] recv = new byte[16];
int read = inputStream.read(recv);
assertThat(read, not(is(0)));
inputStream.close();
read = inputStream.read(recv);
assertThat(read, is(-1));
handlerComplete.countDown();
});
Session session = wsClient.connectToServer(client, uri);
session.getBasicRemote().sendBinary(BufferUtil.toBuffer(data));
assertTrue(handlerComplete.await(5, TimeUnit.SECONDS));
session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "close from test"));
assertTrue(client.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(client.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
assertThat(client.closeReason.getReasonPhrase(), is("close from test"));
}
private byte[] randomBytes(int size) private byte[] randomBytes(int size)
{ {
byte[] data = new byte[size]; byte[] data = new byte[size];
@ -134,6 +196,37 @@ public class BinaryStreamTest
return data; return data;
} }
@ClientEndpoint
public static class BasicClientBinaryStreamer
{
public interface MessageHandler
{
void accept(Session session, InputStream inputStream) throws Exception;
}
private final MessageHandler handler;
private final CountDownLatch closeLatch = new CountDownLatch(1);
private CloseReason closeReason;
public BasicClientBinaryStreamer(MessageHandler consumer)
{
this.handler = consumer;
}
@OnMessage
public void echoed(Session session, InputStream input) throws Exception
{
handler.accept(session, input);
}
@OnClose
public void onClosed(CloseReason closeReason)
{
this.closeReason = closeReason;
closeLatch.countDown();
}
}
@ClientEndpoint @ClientEndpoint
public static class ClientBinaryStreamer public static class ClientBinaryStreamer
{ {

View File

@ -32,7 +32,6 @@ import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.CloseInfo; import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.message.MessageAppender;
import org.eclipse.jetty.websocket.common.message.MessageInputStream; import org.eclipse.jetty.websocket.common.message.MessageInputStream;
import org.eclipse.jetty.websocket.common.message.MessageReader; import org.eclipse.jetty.websocket.common.message.MessageReader;
import org.eclipse.jetty.websocket.common.message.NullMessage; import org.eclipse.jetty.websocket.common.message.NullMessage;
@ -105,7 +104,7 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
} }
else if (events.onBinary.isStreaming()) else if (events.onBinary.isStreaming())
{ {
final MessageInputStream inputStream = new MessageInputStream(session); MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream; activeMessage = inputStream;
dispatch(() -> dispatch(() ->
{ {
@ -115,11 +114,11 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
} }
catch (Throwable t) catch (Throwable t)
{ {
// dispatched calls need to be reported
session.close(t); session.close(t);
return;
} }
inputStream.close(); inputStream.handlerComplete();
}); });
} }
else else
@ -262,22 +261,21 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
} }
else if (events.onText.isStreaming()) else if (events.onText.isStreaming())
{ {
MessageInputStream inputStream = new MessageInputStream(session); MessageReader reader = new MessageReader(session);
activeMessage = new MessageReader(inputStream); activeMessage = reader;
final MessageAppender msg = activeMessage;
dispatch(() -> dispatch(() ->
{ {
try try
{ {
events.onText.call(websocket, session, msg); events.onText.call(websocket, session, reader);
} }
catch (Throwable t) catch (Throwable t)
{ {
// dispatched calls need to be reported
session.close(t); session.close(t);
return;
} }
inputStream.close(); reader.handlerComplete();
}); });
} }
else else

View File

@ -521,7 +521,7 @@ public abstract class AbstractWebSocketConnection extends AbstractConnection imp
{ {
ByteBuffer resume = readState.resume(); ByteBuffer resume = readState.resume();
if (resume != null) if (resume != null)
onFillable(resume); getExecutor().execute(() -> onFillable(resume));
} }
@Override @Override

View File

@ -53,8 +53,24 @@ public class MessageInputStream extends InputStream implements MessageAppender
private enum State private enum State
{ {
/**
* Open and waiting for a frame to be delivered in {@link #appendFrame(ByteBuffer, boolean)}.
*/
RESUMED, RESUMED,
/**
* We have suspended the session after reading a websocket frame but have not reached the end of the message.
*/
SUSPENDED, SUSPENDED,
/**
* We have received a frame with fin==true and have suspended until we are signaled that onMessage method exited.
*/
COMPLETE,
/**
* We have read to EOF or someone has called InputStream.close(), any further reads will result in reading -1.
*/
CLOSED CLOSED
} }
@ -76,24 +92,16 @@ public class MessageInputStream extends InputStream implements MessageAppender
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload)); LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload));
// Early non atomic test that we aren't closed to avoid an unnecessary copy (will be checked again later). // Avoid entering synchronized block if there is nothing to do.
if (state == State.CLOSED) boolean bufferIsEmpty = BufferUtil.isEmpty(framePayload);
if (bufferIsEmpty && !fin)
return; return;
// Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
try try
{ {
if (framePayload == null || !framePayload.hasRemaining())
return;
ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect());
BufferUtil.clearToFill(copy);
copy.put(framePayload);
BufferUtil.flipToFlush(copy, 0);
synchronized (this) synchronized (this)
{
if (!bufferIsEmpty)
{ {
switch (state) switch (state)
{ {
@ -105,12 +113,26 @@ public class MessageInputStream extends InputStream implements MessageAppender
state = State.SUSPENDED; state = State.SUSPENDED;
break; break;
case SUSPENDED: default:
throw new IllegalStateException(); throw new IllegalStateException("Incorrect State: " + state.name());
} }
// Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect());
BufferUtil.clearToFill(copy);
copy.put(framePayload);
BufferUtil.flipToFlush(copy, 0);
buffers.put(copy); buffers.put(copy);
} }
if (fin)
{
buffers.add(EOF);
state = State.COMPLETE;
}
}
} }
catch (InterruptedException e) catch (InterruptedException e)
{ {
@ -121,56 +143,59 @@ public class MessageInputStream extends InputStream implements MessageAppender
@Override @Override
public void close() public void close()
{ {
SuspendToken resume = null;
synchronized (this) synchronized (this)
{ {
switch (state) if (state == State.CLOSED)
{
case CLOSED:
return; return;
case SUSPENDED: boolean remainingContent = (state != State.COMPLETE) ||
(!buffers.isEmpty() && buffers.peek() != EOF) ||
(activeBuffer != null && activeBuffer.hasRemaining());
if (remainingContent)
LOG.warn("MessageInputStream closed without fully consuming content {}", session);
state = State.CLOSED;
buffers.clear();
buffers.add(EOF);
}
}
public void handlerComplete()
{
// Close the InputStream.
close();
// May need to resume to resume and read to the next message.
SuspendToken resume;
synchronized (this)
{
resume = suspendToken; resume = suspendToken;
suspendToken = null; suspendToken = null;
state = State.CLOSED;
break;
case RESUMED:
state = State.CLOSED;
break;
} }
buffers.clear();
buffers.offer(EOF);
}
// May need to resume to discard until we reach next message.
if (resume != null) if (resume != null)
resume.resume(); resume.resume();
} }
@Override
public void mark(int readlimit)
{
// Not supported.
}
@Override
public boolean markSupported()
{
return false;
}
@Override
public void messageComplete()
{
if (LOG.isDebugEnabled())
LOG.debug("Message completed");
buffers.offer(EOF);
}
@Override @Override
public int read() throws IOException public int read() throws IOException
{
byte[] bytes = new byte[1];
while (true)
{
int read = read(bytes, 0, 1);
if (read < 0)
return -1;
if (read == 0)
continue;
return bytes[0] & 0xFF;
}
}
@Override
public int read(byte[] b, int off, int len) throws IOException
{ {
try try
{ {
@ -186,6 +211,7 @@ public class MessageInputStream extends InputStream implements MessageAppender
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Waiting {} ms to read", timeoutMs); LOG.debug("Waiting {} ms to read", timeoutMs);
if (timeoutMs < 0) if (timeoutMs < 0)
{ {
// Wait forever until a buffer is available. // Wait forever until a buffer is available.
@ -209,10 +235,14 @@ public class MessageInputStream extends InputStream implements MessageAppender
} }
} }
int result = activeBuffer.get() & 0xFF; ByteBuffer buffer = BufferUtil.toBuffer(b, off, len);
BufferUtil.clearToFill(buffer);
int written = BufferUtil.put(activeBuffer, buffer);
BufferUtil.flipToFlush(buffer, 0);
// If we have no more content we may need to resume to get more data.
if (!activeBuffer.hasRemaining()) if (!activeBuffer.hasRemaining())
{ {
SuspendToken resume = null; SuspendToken resume = null;
synchronized (this) synchronized (this)
{ {
@ -221,6 +251,11 @@ public class MessageInputStream extends InputStream implements MessageAppender
case CLOSED: case CLOSED:
return -1; return -1;
case COMPLETE:
// If we are complete we have read the last frame but
// don't want to resume reading until onMessage() exits.
break;
case SUSPENDED: case SUSPENDED:
resume = suspendToken; resume = suspendToken;
suspendToken = null; suspendToken = null;
@ -228,7 +263,7 @@ public class MessageInputStream extends InputStream implements MessageAppender
break; break;
case RESUMED: case RESUMED:
throw new IllegalStateException(); throw new IllegalStateException("Incorrect State: " + state.name());
} }
} }
@ -237,7 +272,7 @@ public class MessageInputStream extends InputStream implements MessageAppender
resume.resume(); resume.resume();
} }
return result; return written;
} }
catch (InterruptedException x) catch (InterruptedException x)
{ {
@ -248,12 +283,30 @@ public class MessageInputStream extends InputStream implements MessageAppender
} }
} }
@Override
public void messageComplete()
{
// We handle this case in appendFrame with fin==true.
}
@Override @Override
public void reset() throws IOException public void reset() throws IOException
{ {
throw new IOException("reset() not supported"); throw new IOException("reset() not supported");
} }
@Override
public void mark(int readlimit)
{
// Not supported.
}
@Override
public boolean markSupported()
{
return false;
}
private ByteBuffer acquire(int capacity, boolean direct) private ByteBuffer acquire(int capacity, boolean direct)
{ {
ByteBuffer buffer; ByteBuffer buffer;

View File

@ -24,6 +24,8 @@ import java.nio.ByteBuffer;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.websocket.api.Session;
/** /**
* Support class for reading a (single) WebSocket TEXT message via a Reader. * Support class for reading a (single) WebSocket TEXT message via a Reader.
* <p> * <p>
@ -33,6 +35,11 @@ public class MessageReader extends InputStreamReader implements MessageAppender
{ {
private final MessageInputStream stream; private final MessageInputStream stream;
public MessageReader(Session session)
{
this(new MessageInputStream(session));
}
public MessageReader(MessageInputStream stream) public MessageReader(MessageInputStream stream)
{ {
super(stream, StandardCharsets.UTF_8); super(stream, StandardCharsets.UTF_8);
@ -50,4 +57,9 @@ public class MessageReader extends InputStreamReader implements MessageAppender
{ {
this.stream.messageComplete(); this.stream.messageComplete();
} }
public void handlerComplete()
{
this.stream.handlerComplete();
}
} }

View File

@ -18,6 +18,7 @@
package org.eclipse.jetty.websocket.common.message; package org.eclipse.jetty.websocket.common.message;
import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -32,6 +33,7 @@ import org.eclipse.jetty.toolchain.test.jupiter.WorkDir;
import org.eclipse.jetty.toolchain.test.jupiter.WorkDirExtension; import org.eclipse.jetty.toolchain.test.jupiter.WorkDirExtension;
import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.api.SuspendToken; import org.eclipse.jetty.websocket.api.SuspendToken;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -111,9 +113,10 @@ public class MessageInputStreamTest
startLatch.await(); startLatch.await();
// Read it from the stream. // Read it from the stream.
byte[] buf = new byte[32]; ByteArrayOutputStream out = new ByteArrayOutputStream();
int len = stream.read(buf); IO.copy(stream, out);
String message = new String(buf, 0, len, StandardCharsets.UTF_8); byte[] bytes = out.toByteArray();
String message = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8);
// Test it // Test it
assertThat("Error when appending", hadError.get(), is(false)); assertThat("Error when appending", hadError.get(), is(false));
@ -169,9 +172,10 @@ public class MessageInputStreamTest
{ {
// wait for a little bit before sending input closed // wait for a little bit before sending input closed
TimeUnit.MILLISECONDS.sleep(1000); TimeUnit.MILLISECONDS.sleep(1000);
stream.appendFrame(null, true);
stream.messageComplete(); stream.messageComplete();
} }
catch (InterruptedException e) catch (InterruptedException | IOException e)
{ {
hadError.set(true); hadError.set(true);
e.printStackTrace(System.err); e.printStackTrace(System.err);
@ -206,9 +210,10 @@ public class MessageInputStreamTest
session.provideContent(); session.provideContent();
// Read entire message it from the stream. // Read entire message it from the stream.
byte[] buf = new byte[32]; ByteArrayOutputStream out = new ByteArrayOutputStream();
int len = stream.read(buf); IO.copy(stream, out);
String message = new String(buf, 0, len, StandardCharsets.UTF_8); byte[] bytes = out.toByteArray();
String message = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8);
// Test it // Test it
assertThat("Message", message, is("Hello World!")); assertThat("Message", message, is("Hello World!"));