Issue #6566 - utilise the demand interface in the websocket MessageSinks

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2021-08-18 11:01:35 +10:00
parent d3dfe353be
commit 5236e47c42
11 changed files with 94 additions and 64 deletions

View File

@ -24,30 +24,17 @@ import org.eclipse.jetty.util.Callback;
public class BufferCallbackAccumulator
{
private final List<Entry> _entries = new ArrayList<>();
private final ByteBufferPool _bufferPool;
private final boolean _direct;
private static class Entry
{
private final ByteBuffer buffer;
private final Callback callback;
Entry(ByteBuffer buffer, Callback callback)
{
this.buffer = buffer;
this.callback = callback;
}
ByteBuffer buffer;
Callback callback;
}
public BufferCallbackAccumulator()
{
this(null, false);
}
BufferCallbackAccumulator(ByteBufferPool bufferPool, boolean direct)
{
_bufferPool = (bufferPool == null) ? new NullByteBufferPool() : bufferPool;
_direct = direct;
}
public void addEntry(ByteBuffer buffer, Callback callback)
@ -86,7 +73,7 @@ public class BufferCallbackAccumulator
public void writeTo(ByteBuffer buffer)
{
for (Iterator<Entry> iterator = _entries.iterator(); iterator.hasNext(); )
for (Iterator<Entry> iterator = _entries.iterator(); iterator.hasNext();)
{
Entry entry = iterator.next();
buffer.put(entry.buffer);

View File

@ -13,12 +13,11 @@
package org.eclipse.jetty.websocket.core.internal.messages;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import org.eclipse.jetty.io.BufferCallbackAccumulator;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.CoreSession;
@ -29,8 +28,7 @@ import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException;
public class ByteArrayMessageSink extends AbstractMessageSink
{
private static final byte[] EMPTY_BUFFER = new byte[0];
private static final int BUFFER_SIZE = 65535;
private ByteArrayOutputStream out;
private BufferCallbackAccumulator out;
private int size;
public ByteArrayMessageSink(CoreSession session, MethodHandle methodHandle)
@ -55,8 +53,8 @@ public class ByteArrayMessageSink extends AbstractMessageSink
long maxBinaryMessageSize = session.getMaxBinaryMessageSize();
if (maxBinaryMessageSize > 0 && size > maxBinaryMessageSize)
{
throw new MessageTooLargeException(String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d",
size, maxBinaryMessageSize));
callback.failed(new MessageTooLargeException(
String.format("Binary message too large: (actual) %,d > (configured max binary message size) %,d", size, maxBinaryMessageSize)));
}
// If we are fin and no OutputStream has been created we don't need to aggregate.
@ -71,19 +69,26 @@ public class ByteArrayMessageSink extends AbstractMessageSink
methodHandle.invoke(EMPTY_BUFFER, 0, 0);
callback.succeeded();
session.demand(1);
return;
}
aggregatePayload(frame);
aggregatePayload(frame, callback);
// If the methodHandle throws we don't want to fail callback twice.
callback = Callback.NOOP;
if (frame.isFin())
{
byte[] buf = out.toByteArray();
methodHandle.invoke(buf, 0, buf.length);
}
callback.succeeded();
session.demand(1);
}
catch (Throwable t)
{
if (out != null)
out.fail(t);
callback.failed(t);
}
finally
@ -97,14 +102,14 @@ public class ByteArrayMessageSink extends AbstractMessageSink
}
}
private void aggregatePayload(Frame frame) throws IOException
private void aggregatePayload(Frame frame, Callback callback)
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new ByteArrayOutputStream(BUFFER_SIZE);
BufferUtil.writeTo(payload, out);
out = new BufferCallbackAccumulator();
out.addEntry(payload, callback);
}
}
}

View File

@ -13,13 +13,13 @@
package org.eclipse.jetty.websocket.core.internal.messages;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.Objects;
import org.eclipse.jetty.io.BufferCallbackAccumulator;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.CoreSession;
@ -29,8 +29,7 @@ import org.eclipse.jetty.websocket.core.exception.MessageTooLargeException;
public class ByteBufferMessageSink extends AbstractMessageSink
{
private static final int BUFFER_SIZE = 65535;
private ByteArrayOutputStream out;
private BufferCallbackAccumulator out;
private int size;
public ByteBufferMessageSink(CoreSession session, MethodHandle methodHandle)
@ -68,41 +67,58 @@ public class ByteBufferMessageSink extends AbstractMessageSink
methodHandle.invoke(BufferUtil.EMPTY_BUFFER);
callback.succeeded();
session.demand(1);
return;
}
aggregatePayload(frame);
if (frame.isFin())
methodHandle.invoke(ByteBuffer.wrap(out.toByteArray()));
aggregatePayload(frame, callback);
callback.succeeded();
// If the methodHandle throws we don't want to fail callback twice.
callback = Callback.NOOP;
if (frame.isFin())
{
ByteBufferPool bufferPool = session.getByteBufferPool();
ByteBuffer buffer = bufferPool.acquire(out.getLength(), false);
BufferUtil.clearToFill(buffer);
out.writeTo(buffer);
BufferUtil.flipToFlush(buffer, 0);
try
{
methodHandle.invoke(buffer);
}
finally
{
bufferPool.release(buffer);
}
session.demand(1);
}
}
catch (Throwable t)
{
if (out != null)
out.fail(t);
callback.failed(t);
}
finally
{
if (frame.isFin())
{
// reset
out = null;
size = 0;
}
}
}
private void aggregatePayload(Frame frame) throws IOException
private void aggregatePayload(Frame frame, Callback callback)
{
if (frame.hasPayload())
{
ByteBuffer payload = frame.getPayload();
if (out == null)
out = new ByteArrayOutputStream(BUFFER_SIZE);
BufferUtil.writeTo(payload, out);
payload.position(payload.limit()); // consume buffer
out = new BufferCallbackAccumulator();
out.addEntry(payload, callback);
}
}
}

View File

@ -135,22 +135,28 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
});
}
Callback frameCallback = callback;
Callback frameCallback;
if (frame.isFin())
{
// This is the final frame we should wait for the frame callback and the dispatched thread.
Callback.Completable completableCallback = new Callback.Completable();
frameCallback = completableCallback;
CompletableFuture.allOf(dispatchComplete, completableCallback).whenComplete((aVoid, throwable) ->
Callback.Completable finComplete = Callback.Completable.from(callback);
frameCallback = finComplete;
CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete((aVoid, throwable) ->
{
typeSink = null;
dispatchComplete = null;
if (throwable != null)
callback.failed(throwable);
else
callback.succeeded();
if (throwable == null)
session.demand(1);
});
}
else
{
frameCallback = Callback.from(() ->
{
callback.succeeded();
session.demand(1);
}, callback::failed);
}
typeSink.accept(frame, frameCallback);
}

View File

@ -41,6 +41,7 @@ public class PartialByteArrayMessageSink extends AbstractMessageSink
}
callback.succeeded();
session.demand(1);
}
catch (Throwable t)
{

View File

@ -35,6 +35,7 @@ public class PartialByteBufferMessageSink extends AbstractMessageSink
methodHandle.invoke(frame.getPayload(), frame.isFin());
callback.succeeded();
session.demand(1);
}
catch (Throwable t)
{

View File

@ -51,6 +51,7 @@ public class PartialStringMessageSink extends AbstractMessageSink
}
callback.succeeded();
session.demand(1);
}
catch (Throwable t)
{

View File

@ -53,6 +53,7 @@ public class StringMessageSink extends AbstractMessageSink
methodHandle.invoke(out.toString());
callback.succeeded();
session.demand(1);
}
catch (Throwable t)
{

View File

@ -178,6 +178,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionOpened(session));
callback.succeeded();
coreSession.demand(1);
}
catch (Throwable cause)
{
@ -321,6 +322,12 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
}
}
@Override
public boolean isDemanding()
{
return true;
}
public Set<MessageHandler> getMessageHandlers()
{
return messageHandlerMap.values().stream()
@ -591,6 +598,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
ByteBuffer payload = BufferUtil.copy(frame.getPayload());
coreSession.sendFrame(new Frame(OpCode.PONG).setPayload(payload), Callback.NOOP, false);
callback.succeeded();
coreSession.demand(1);
}
public void onPong(Frame frame, Callback callback)
@ -613,6 +621,7 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
}
}
callback.succeeded();
coreSession.demand(1);
}
public void onText(Frame frame, Callback callback)

View File

@ -17,6 +17,8 @@ import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.NullByteBufferPool;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.WebSocketComponents;
import org.junit.jupiter.api.AfterAll;
@ -38,6 +40,7 @@ public abstract class AbstractSessionTest
Object websocketPojo = new DummyEndpoint();
UpgradeRequest upgradeRequest = new UpgradeRequestAdapter();
JavaxWebSocketFrameHandler frameHandler = container.newFrameHandler(websocketPojo, upgradeRequest);
ByteBufferPool bufferPool = new NullByteBufferPool();
CoreSession coreSession = new CoreSession.Empty()
{
@Override
@ -45,6 +48,12 @@ public abstract class AbstractSessionTest
{
return components;
}
@Override
public ByteBufferPool getByteBufferPool()
{
return bufferPool;
}
};
session = new JavaxWebSocketSession(container, coreSession, frameHandler, container.getFrameHandlerFactory()
.newDefaultEndpointConfig(websocketPojo.getClass()));

View File

@ -82,6 +82,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler
private WebSocketSession session;
private SuspendState state = SuspendState.DEMANDING;
private Runnable delayedOnFrame;
private CoreSession coreSession;
public JettyWebSocketFrameHandler(WebSocketContainer container,
Object endpointInstance,
@ -150,6 +151,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler
try
{
customizer.customize(coreSession);
this.coreSession = coreSession;
session = new WebSocketSession(container, coreSession, this);
if (!session.isOpen())
throw new IllegalStateException("Session is not open");
@ -226,16 +228,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler
// Demand after succeeding any received frame
Callback demandingCallback = Callback.from(() ->
{
try
{
demand();
}
catch (Throwable t)
{
callback.failed(t);
return;
}
demand();
callback.succeeded();
},
callback::failed
@ -253,13 +246,13 @@ public class JettyWebSocketFrameHandler implements FrameHandler
onPongFrame(frame, demandingCallback);
break;
case OpCode.TEXT:
onTextFrame(frame, demandingCallback);
onTextFrame(frame, callback);
break;
case OpCode.BINARY:
onBinaryFrame(frame, demandingCallback);
onBinaryFrame(frame, callback);
break;
case OpCode.CONTINUATION:
onContinuationFrame(frame, demandingCallback);
onContinuationFrame(frame, callback);
break;
default:
callback.failed(new IllegalStateException());
@ -342,6 +335,7 @@ public class JettyWebSocketFrameHandler implements FrameHandler
if (activeMessageSink == null)
{
callback.succeeded();
coreSession.demand(1);
return;
}