Issue #4475 - suspend/resume to control reading frames while streaming

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2020-01-17 13:11:48 +11:00
parent c7b6ccca98
commit 3fd7094c01
6 changed files with 276 additions and 141 deletions

View File

@ -116,27 +116,24 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
if (activeMessage == null)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Binary Message InputStream");
}
final MessageInputStream stream = new MessageInputStream();
final MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream;
// Always dispatch streaming read to another thread.
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.callBinaryStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
onFatalError(e);
}
events.callBinaryStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
session.close(e);
}
stream.close();
});
}
}
@ -330,28 +327,25 @@ public class JsrAnnotatedEventDriver extends AbstractJsrEventDriver
if (activeMessage == null)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Text Message Writer");
}
final MessageReader stream = new MessageReader(new MessageInputStream());
activeMessage = stream;
MessageInputStream inputStream = new MessageInputStream(session);
final MessageReader reader = new MessageReader(inputStream);
activeMessage = inputStream;
// Always dispatch streaming read to another thread.
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.callTextStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
onFatalError(e);
}
events.callTextStream(jsrsession.getAsyncRemote(), websocket, reader);
}
catch (Throwable e)
{
session.close(e);
}
inputStream.close();
});
}
}

View File

@ -23,7 +23,6 @@ import java.io.InputStream;
import java.io.Reader;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.MessageHandler;
@ -88,17 +87,22 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
}
else if (wrapper.wantsStreams())
{
final MessageInputStream stream = new MessageInputStream();
activeMessage = stream;
dispatch(new Runnable()
@SuppressWarnings("unchecked")
MessageHandler.Whole<InputStream> handler = (Whole<InputStream>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
@SuppressWarnings("unchecked")
@Override
public void run()
try
{
MessageHandler.Whole<InputStream> handler = (Whole<InputStream>)wrapper.getHandler();
handler.onMessage(stream);
handler.onMessage(inputStream);
}
catch (Throwable t)
{
session.close(t);
}
inputStream.close();
});
}
else
@ -191,35 +195,23 @@ public class JsrEndpointEventDriver extends AbstractJsrEventDriver
}
else if (wrapper.wantsStreams())
{
final CountDownLatch completed = new CountDownLatch(1);
final MessageReader stream = new MessageReader(new MessageInputStream())
@SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
MessageReader reader = new MessageReader(inputStream);
activeMessage = reader;
dispatch(() ->
{
@Override
public void messageComplete()
try
{
super.messageComplete();
try
{
completed.await();
}
catch (Exception e)
{
throw new RuntimeException(e);
}
handler.onMessage(reader);
}
catch (Throwable t)
{
session.close(t);
}
};
activeMessage = stream;
dispatch(new Runnable()
{
@SuppressWarnings("unchecked")
@Override
public void run()
{
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
handler.onMessage(stream);
completed.countDown();
}
inputStream.close();
});
}
else

View File

@ -97,23 +97,21 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
{
if (events.onBinary.isStreaming())
{
activeMessage = new MessageInputStream();
final MessageAppender msg = activeMessage;
dispatch(new Runnable()
final MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.onBinary.call(websocket, session, msg);
}
catch (Throwable t)
{
// dispatched calls need to be reported
onError(t);
}
events.onBinary.call(websocket, session, inputStream);
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
}
inputStream.close();
});
}
else
@ -215,23 +213,22 @@ public class JettyAnnotatedEventDriver extends AbstractEventDriver
{
if (events.onText.isStreaming())
{
activeMessage = new MessageReader(new MessageInputStream());
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = new MessageReader(inputStream);
final MessageAppender msg = activeMessage;
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.onText.call(websocket, session, msg);
}
catch (Throwable t)
{
// dispatched calls need to be reported
onError(t);
}
events.onText.call(websocket, session, msg);
}
catch (Throwable t)
{
// dispatched calls need to be reported
session.close(t);
}
inputStream.close();
});
}
else

View File

@ -24,11 +24,14 @@ import java.nio.ByteBuffer;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.SuspendToken;
import org.eclipse.jetty.websocket.common.WebSocketSession;
/**
* Support class for reading a (single) WebSocket BINARY message via a InputStream.
@ -40,63 +43,59 @@ public class MessageInputStream extends InputStream implements MessageAppender
private static final Logger LOG = Log.getLogger(MessageInputStream.class);
private static final ByteBuffer EOF = ByteBuffer.allocate(0).asReadOnlyBuffer();
private final Session session;
private final ByteBufferPool bufferPool;
private final BlockingDeque<ByteBuffer> buffers = new LinkedBlockingDeque<>();
private AtomicBoolean closed = new AtomicBoolean(false);
private final long timeoutMs;
private ByteBuffer activeBuffer = null;
private volatile boolean closed = false;
private volatile SuspendToken suspendToken;
private static boolean isTheEofBuffer(ByteBuffer buf)
public MessageInputStream(Session session)
{
@SuppressWarnings("ReferenceEquality")
boolean isTheEofBuffer = (buf == EOF);
return isTheEofBuffer;
this(session, -1);
}
public MessageInputStream()
{
this(-1);
}
public MessageInputStream(int timeoutMs)
public MessageInputStream(Session session, int timeoutMs)
{
this.timeoutMs = timeoutMs;
this.session = session;
this.bufferPool = (session instanceof WebSocketSession) ? ((WebSocketSession)session).getBufferPool() : null;
this.suspendToken = session.suspend();
}
@Override
public void appendFrame(ByteBuffer framePayload, boolean fin) throws IOException
{
if (LOG.isDebugEnabled())
{
LOG.debug("Appending {} chunk: {}", fin ? "final" : "non-final", BufferUtil.toDetailString(framePayload));
}
// If closed, we should just toss incoming payloads into the bit bucket.
if (closed.get())
{
if (closed)
return;
}
// Put the payload into the queue, by copying it.
// Copying is necessary because the payload will
// be processed after this method returns.
try
{
if (framePayload == null)
{
// skip if no payload
if (framePayload == null || !framePayload.hasRemaining())
return;
}
int capacity = framePayload.remaining();
if (capacity <= 0)
ByteBuffer copy = acquire(framePayload.remaining(), framePayload.isDirect());
BufferUtil.clearToFill(copy);
copy.put(framePayload);
BufferUtil.flipToFlush(copy, 0);
synchronized (this)
{
// skip if no payload data to copy
return;
if (closed)
return;
if (suspendToken == null)
suspendToken = session.suspend();
buffers.put(copy);
}
// TODO: the copy buffer should be pooled too, but no buffer pool available from here.
ByteBuffer copy = framePayload.isDirect() ? ByteBuffer.allocateDirect(capacity) : ByteBuffer.allocate(capacity);
copy.put(framePayload).flip();
buffers.put(copy);
}
catch (InterruptedException e)
{
@ -105,20 +104,32 @@ public class MessageInputStream extends InputStream implements MessageAppender
finally
{
if (fin)
{
buffers.offer(EOF);
}
}
}
@Override
public void close() throws IOException
private ByteBuffer acquire(int capacity, boolean direct)
{
if (closed.compareAndSet(false, true))
ByteBuffer buffer;
if (bufferPool != null)
buffer = bufferPool.acquire(capacity, direct);
else
buffer = direct ? BufferUtil.allocateDirect(capacity) : BufferUtil.allocate(capacity);
return buffer;
}
@Override
public void close()
{
synchronized (this)
{
closed = true;
buffers.clear();
buffers.offer(EOF);
super.close();
}
// Resume to discard util we reach next message.
resume();
}
@Override
@ -146,7 +157,7 @@ public class MessageInputStream extends InputStream implements MessageAppender
{
try
{
if (closed.get())
if (closed)
{
if (LOG.isDebugEnabled())
LOG.debug("Stream closed");
@ -168,34 +179,46 @@ public class MessageInputStream extends InputStream implements MessageAppender
// Wait at most for the given timeout.
activeBuffer = buffers.poll(timeoutMs, TimeUnit.MILLISECONDS);
if (activeBuffer == null)
{
throw new IOException(String.format("Read timeout: %,dms expired", timeoutMs));
}
}
if (isTheEofBuffer(activeBuffer))
if (activeBuffer == EOF)
{
if (LOG.isDebugEnabled())
LOG.debug("Reached EOF");
// Be sure that this stream cannot be reused.
closed.set(true);
// Removed buffers that may have remained in the queue.
buffers.clear();
close();
return -1;
}
}
return activeBuffer.get() & 0xFF;
int result = activeBuffer.get() & 0xFF;
if (!activeBuffer.hasRemaining())
resume();
return result;
}
catch (InterruptedException x)
{
if (LOG.isDebugEnabled())
LOG.debug("Interrupted while waiting to read", x);
closed.set(true);
close();
return -1;
}
}
private void resume()
{
SuspendToken resume;
synchronized (this)
{
resume = suspendToken;
suspendToken = null;
}
if (resume != null)
resume.resume();
}
@Override
public void reset() throws IOException
{

View File

@ -0,0 +1,129 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.websocket.common.message;
import java.io.IOException;
import java.net.InetSocketAddress;
import org.eclipse.jetty.websocket.api.CloseStatus;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.SuspendToken;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
public class EmptySession implements Session, SuspendToken
{
@Override
public void close()
{
}
@Override
public void close(CloseStatus closeStatus)
{
}
@Override
public void close(int statusCode, String reason)
{
}
@Override
public void disconnect() throws IOException
{
}
@Override
public long getIdleTimeout()
{
return -1;
}
@Override
public InetSocketAddress getLocalAddress()
{
return null;
}
@Override
public WebSocketPolicy getPolicy()
{
return null;
}
@Override
public String getProtocolVersion()
{
return null;
}
@Override
public RemoteEndpoint getRemote()
{
return null;
}
@Override
public InetSocketAddress getRemoteAddress()
{
return null;
}
@Override
public UpgradeRequest getUpgradeRequest()
{
return null;
}
@Override
public UpgradeResponse getUpgradeResponse()
{
return null;
}
@Override
public boolean isOpen()
{
return false;
}
@Override
public boolean isSecure()
{
return false;
}
@Override
public void setIdleTimeout(long ms)
{
}
@Override
public SuspendToken suspend()
{
return this;
}
@Override
public void resume()
{
}
}

View File

@ -48,7 +48,7 @@ public class MessageInputStreamTest
@Test
public void testBasicAppendRead() throws IOException
{
try (MessageInputStream stream = new MessageInputStream())
try (MessageInputStream stream = new MessageInputStream(new EmptySession()))
{
Assertions.assertTimeoutPreemptively(ofSeconds(5), () ->
{
@ -71,7 +71,7 @@ public class MessageInputStreamTest
@Test
public void testBlockOnRead() throws Exception
{
try (MessageInputStream stream = new MessageInputStream())
try (MessageInputStream stream = new MessageInputStream(new EmptySession()))
{
final AtomicBoolean hadError = new AtomicBoolean(false);
final CountDownLatch startLatch = new CountDownLatch(1);
@ -123,7 +123,7 @@ public class MessageInputStreamTest
@Test
public void testBlockOnReadInitial() throws IOException
{
try (MessageInputStream stream = new MessageInputStream())
try (MessageInputStream stream = new MessageInputStream(new EmptySession()))
{
final AtomicBoolean hadError = new AtomicBoolean(false);
@ -163,7 +163,7 @@ public class MessageInputStreamTest
@Test
public void testReadByteNoBuffersClosed() throws IOException
{
try (MessageInputStream stream = new MessageInputStream())
try (MessageInputStream stream = new MessageInputStream(new EmptySession()))
{
final AtomicBoolean hadError = new AtomicBoolean(false);
@ -202,7 +202,7 @@ public class MessageInputStreamTest
@Test
public void testAppendEmptyPayloadRead() throws IOException
{
try (MessageInputStream stream = new MessageInputStream())
try (MessageInputStream stream = new MessageInputStream(new EmptySession()))
{
Assertions.assertTimeoutPreemptively(ofSeconds(10), () ->
{
@ -229,7 +229,7 @@ public class MessageInputStreamTest
@Test
public void testAppendNullPayloadRead() throws IOException
{
try (MessageInputStream stream = new MessageInputStream())
try (MessageInputStream stream = new MessageInputStream(new EmptySession()))
{
Assertions.assertTimeoutPreemptively(ofSeconds(10), () ->
{