Issue #5368 - ensure onMessage exits before next frame is read

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-10-01 15:42:50 +10:00
parent e3ed05fc1c
commit 941ffcead7
5 changed files with 98 additions and 67 deletions

View File

@ -118,10 +118,8 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Binary Message InputStream"); LOG.debug("Binary Message InputStream");
final MessageInputStream stream = new MessageInputStream(session); MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream; activeMessage = stream;
// Always dispatch streaming read to another thread.
dispatch(() -> dispatch(() ->
{ {
try try
@ -329,11 +327,8 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Text Message Writer"); LOG.debug("Text Message Writer");
MessageInputStream inputStream = new MessageInputStream(session); MessageReader reader = new MessageReader(session);
final MessageReader reader = new MessageReader(inputStream); activeMessage = reader;
activeMessage = inputStream;
// Always dispatch streaming read to another thread.
dispatch(() -> dispatch(() ->
{ {
try try
@ -343,9 +338,10 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
catch (Throwable e) catch (Throwable e)
{ {
session.close(e); session.close(e);
return;
} }
inputStream.close(); reader.handlerComplete();
}); });
} }
} }

View File

@ -100,9 +100,10 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
catch (Throwable t) catch (Throwable t)
{ {
session.close(t); session.close(t);
return;
} }
inputStream.close(); inputStream.handlerComplete();
}); });
} }
else else
@ -197,8 +198,7 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
{ {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler(); MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session); MessageReader reader = new MessageReader(session);
MessageReader reader = new MessageReader(inputStream);
activeMessage = reader; activeMessage = reader;
dispatch(() -> dispatch(() ->
{ {
@ -209,9 +209,10 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
catch (Throwable t) catch (Throwable t)
{ {
session.close(t); session.close(t);
return;
} }
inputStream.close(); reader.handlerComplete();
}); });
} }
else else

View File

@ -32,7 +32,6 @@ import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.common.CloseInfo; import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.message.MessageAppender;
import org.eclipse.jetty.websocket.common.message.MessageInputStream; import org.eclipse.jetty.websocket.common.message.MessageInputStream;
import org.eclipse.jetty.websocket.common.message.MessageReader; import org.eclipse.jetty.websocket.common.message.MessageReader;
import org.eclipse.jetty.websocket.common.message.NullMessage; import org.eclipse.jetty.websocket.common.message.NullMessage;
@ -105,7 +104,7 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
} }
else if (events.onBinary.isStreaming()) else if (events.onBinary.isStreaming())
{ {
final MessageInputStream inputStream = new MessageInputStream(session); MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream; activeMessage = inputStream;
dispatch(() -> dispatch(() ->
{ {
@ -115,11 +114,11 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
} }
catch (Throwable t) catch (Throwable t)
{ {
// dispatched calls need to be reported
session.close(t); session.close(t);
return;
} }
inputStream.close(); inputStream.handlerComplete();
}); });
} }
else else
@ -262,22 +261,21 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
} }
else if (events.onText.isStreaming()) else if (events.onText.isStreaming())
{ {
MessageInputStream inputStream = new MessageInputStream(session); MessageReader reader = new MessageReader(session);
activeMessage = new MessageReader(inputStream); activeMessage = reader;
final MessageAppender msg = activeMessage;
dispatch(() -> dispatch(() ->
{ {
try try
{ {
events.onText.call(websocket, session, msg); events.onText.call(websocket, session, reader);
} }
catch (Throwable t) catch (Throwable t)
{ {
// dispatched calls need to be reported
session.close(t); session.close(t);
return;
} }
inputStream.close(); reader.handlerComplete();
}); });
} }
else else

View File

@ -55,6 +55,7 @@ public class MessageInputStream extends InputStream implements MessageAppender
{ {
RESUMED, RESUMED,
SUSPENDED, SUSPENDED,
COMPLETE,
CLOSED CLOSED
} }
@ -76,23 +77,11 @@ public class MessageInputStream extends InputStream implements MessageAppender
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload)); LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload));
// Early non atomic test that we aren't closed to avoid an unnecessary copy (will be checked again later).
if (state == State.CLOSED)
return;
// Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
try try
{ {
if (framePayload == null || !framePayload.hasRemaining()) if (BufferUtil.isEmpty(framePayload))
return; return;
ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect());
BufferUtil.clearToFill(copy);
copy.put(framePayload);
BufferUtil.flipToFlush(copy, 0);
synchronized (this) synchronized (this)
{ {
switch (state) switch (state)
@ -105,11 +94,14 @@ public class MessageInputStream extends InputStream implements MessageAppender
state = State.SUSPENDED; state = State.SUSPENDED;
break; break;
case SUSPENDED: default:
throw new IllegalStateException(); throw new IllegalStateException();
} }
buffers.put(copy); // Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
buffers.put(copy(framePayload));
} }
} }
catch (InterruptedException e) catch (InterruptedException e)
@ -121,7 +113,23 @@ public class MessageInputStream extends InputStream implements MessageAppender
@Override @Override
public void close() public void close()
{ {
SuspendToken resume = null; synchronized (this)
{
if (state == State.CLOSED)
return;
state = State.CLOSED;
buffers.clear();
buffers.offer(EOF);
}
}
@Override
public void messageComplete()
{
if (LOG.isDebugEnabled())
LOG.debug("Message completed");
synchronized (this) synchronized (this)
{ {
switch (state) switch (state)
@ -130,45 +138,35 @@ public class MessageInputStream extends InputStream implements MessageAppender
return; return;
case SUSPENDED: case SUSPENDED:
resume = suspendToken; case RESUMED:
suspendToken = null; state = State.COMPLETE;
state = State.CLOSED;
break; break;
case RESUMED: default:
state = State.CLOSED; throw new IllegalStateException();
break;
} }
buffers.offer(EOF);
}
}
public void handlerComplete()
{
// May need to resume to resume and read to the next message.
SuspendToken resume;
synchronized (this)
{
state = State.CLOSED;
resume = suspendToken;
suspendToken = null;
buffers.clear(); buffers.clear();
buffers.offer(EOF); buffers.offer(EOF);
} }
// May need to resume to discard until we reach next message.
if (resume != null) if (resume != null)
resume.resume(); resume.resume();
} }
@Override
public void mark(int readlimit)
{
// Not supported.
}
@Override
public boolean markSupported()
{
return false;
}
@Override
public void messageComplete()
{
if (LOG.isDebugEnabled())
LOG.debug("Message completed");
buffers.offer(EOF);
}
@Override @Override
public int read() throws IOException public int read() throws IOException
{ {
@ -186,6 +184,7 @@ public class MessageInputStream extends InputStream implements MessageAppender
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Waiting {} ms to read", timeoutMs); LOG.debug("Waiting {} ms to read", timeoutMs);
if (timeoutMs < 0) if (timeoutMs < 0)
{ {
// Wait forever until a buffer is available. // Wait forever until a buffer is available.
@ -212,7 +211,6 @@ public class MessageInputStream extends InputStream implements MessageAppender
int result = activeBuffer.get() & 0xFF; int result = activeBuffer.get() & 0xFF;
if (!activeBuffer.hasRemaining()) if (!activeBuffer.hasRemaining())
{ {
SuspendToken resume = null; SuspendToken resume = null;
synchronized (this) synchronized (this)
{ {
@ -221,6 +219,11 @@ public class MessageInputStream extends InputStream implements MessageAppender
case CLOSED: case CLOSED:
return -1; return -1;
case COMPLETE:
// If we are complete we have read the last frame but
// don't want to resume reading until onMessage() exits.
break;
case SUSPENDED: case SUSPENDED:
resume = suspendToken; resume = suspendToken;
suspendToken = null; suspendToken = null;
@ -254,6 +257,27 @@ public class MessageInputStream extends InputStream implements MessageAppender
throw new IOException("reset() not supported"); throw new IOException("reset() not supported");
} }
@Override
public void mark(int readlimit)
{
// Not supported.
}
@Override
public boolean markSupported()
{
return false;
}
private ByteBuffer copy(ByteBuffer buffer)
{
ByteBuffer copy = acquire(buffer.remaining(), buffer.isDirect());
BufferUtil.clearToFill(copy);
copy.put(buffer);
BufferUtil.flipToFlush(copy, 0);
return copy;
}
private ByteBuffer acquire(int capacity, boolean direct) private ByteBuffer acquire(int capacity, boolean direct)
{ {
ByteBuffer buffer; ByteBuffer buffer;

View File

@ -24,6 +24,8 @@ import java.nio.ByteBuffer;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.websocket.api.Session;
/** /**
* Support class for reading a (single) WebSocket TEXT message via a Reader. * Support class for reading a (single) WebSocket TEXT message via a Reader.
* <p> * <p>
@ -33,6 +35,11 @@ public class MessageReader extends InputStreamReader implements MessageAppender
{ {
private final MessageInputStream stream; private final MessageInputStream stream;
public MessageReader(Session session)
{
this(new MessageInputStream(session));
}
public MessageReader(MessageInputStream stream) public MessageReader(MessageInputStream stream)
{ {
super(stream, StandardCharsets.UTF_8); super(stream, StandardCharsets.UTF_8);
@ -50,4 +57,9 @@ public class MessageReader extends InputStreamReader implements MessageAppender
{ {
this.stream.messageComplete(); this.stream.messageComplete();
} }
public void handlerComplete()
{
this.stream.handlerComplete();
}
} }