Issue #9682 - notify WebSocket message sinks of connection close

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2023-06-07 15:45:16 +10:00
parent a471c08717
commit a0fe47e1e8
14 changed files with 238 additions and 38 deletions

View File

@ -89,11 +89,14 @@ public class ByteBufferCallbackAccumulator
public void fail(Throwable t)
{
for (Entry entry : _entries)
// In some usages the callback recursively fails the accumulator.
// So we copy and clear to avoid double completing the callback.
ArrayList<Entry> entries = new ArrayList<>(_entries);
_entries.clear();
_length = 0;
for (Entry entry : entries)
{
entry.callback.failed(t);
}
_entries.clear();
_length = 0;
}
}

View File

@ -28,4 +28,9 @@ public abstract class AbstractMessageSink implements MessageSink
this.session = Objects.requireNonNull(session, "CoreSession");
this.methodHandle = Objects.requireNonNull(methodHandle, "MethodHandle");
}
@Override
public void fail(Throwable failure)
{
}
}

View File

@ -106,4 +106,11 @@ public class ByteArrayMessageSink extends AbstractMessageSink
}
}
}
@Override
public void fail(Throwable failure)
{
if (out != null)
out.fail(failure);
}
}

View File

@ -113,4 +113,11 @@ public class ByteBufferMessageSink extends AbstractMessageSink
}
}
}
@Override
public void fail(Throwable failure)
{
if (out != null)
out.fail(failure);
}
}

View File

@ -164,4 +164,11 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
typeSink.accept(frame, frameCallback);
}
@Override
public void fail(Throwable failure)
{
if (typeSink != null)
typeSink.fail(failure);
}
}

View File

@ -127,40 +127,6 @@ public class MessageInputStream extends InputStream implements MessageSink
return fillLen;
}
@Override
public void close() throws IOException
{
if (LOG.isDebugEnabled())
LOG.debug("close()");
ArrayList<Entry> entries = new ArrayList<>();
try (AutoLock l = lock.lock())
{
if (closed)
return;
closed = true;
if (currentEntry != null)
{
entries.add(currentEntry);
currentEntry = null;
}
// Clear queue and fail all entries.
entries.addAll(buffers);
buffers.clear();
buffers.offer(CLOSED);
}
// Succeed all entries as we don't need them anymore (failing would close the connection).
for (Entry e : entries)
{
e.callback.succeeded();
}
super.close();
}
public void setTimeout(long timeoutMs)
{
this.timeoutMs = timeoutMs;
@ -218,6 +184,56 @@ public class MessageInputStream extends InputStream implements MessageSink
}
}
@Override
public void close() throws IOException
{
fail(null);
}
@Override
public void fail(Throwable failure)
{
if (LOG.isDebugEnabled())
LOG.debug("close()");
ArrayList<Entry> entries = new ArrayList<>();
try (AutoLock l = lock.lock())
{
if (closed)
return;
closed = true;
if (currentEntry != null)
{
entries.add(currentEntry);
currentEntry = null;
}
// Clear queue and fail all entries.
entries.addAll(buffers);
buffers.clear();
buffers.offer(CLOSED);
}
// Succeed all entries as we don't need them anymore (failing would close the connection).
for (Entry e : entries)
{
if (failure == null)
e.callback.succeeded();
else
e.callback.failed(failure);
}
try
{
super.close();
}
catch (IOException e)
{
LOG.debug("Failure Closing InputStream", e);
}
}
private static class Entry
{
public ByteBuffer buffer;

View File

@ -87,6 +87,12 @@ public class MessageReader extends Reader implements MessageSink
stream.close();
}
@Override
public void fail(Throwable failure)
{
stream.fail(failure);
}
@Override
public void accept(Frame frame, Callback callback)
{

View File

@ -29,4 +29,12 @@ public interface MessageSink
* @param callback the callback for how the frame was consumed
*/
void accept(Frame frame, Callback callback);
/**
* <p>Fail the message sink.</p>
* <p>Release any resources and fail all stored callbacks as {@link #accept(Frame, Callback)} will never be called again.</p>
*
* @param failure the failure that occurred.
*/
void fail(Throwable failure);
}

View File

@ -22,7 +22,7 @@ import org.eclipse.jetty.websocket.core.Frame;
public class PartialByteArrayMessageSink extends AbstractMessageSink
{
private static byte[] EMPTY_BUFFER = new byte[0];
private static final byte[] EMPTY_BUFFER = new byte[0];
public PartialByteArrayMessageSink(CoreSession session, MethodHandle methodHandle)
{

View File

@ -69,4 +69,11 @@ public class StringMessageSink extends AbstractMessageSink
}
}
}
@Override
public void fail(Throwable failure)
{
if (out != null)
out.reset();
}
}

View File

@ -38,6 +38,7 @@ import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.exception.CloseException;
import org.eclipse.jetty.websocket.core.exception.ProtocolException;
import org.eclipse.jetty.websocket.core.exception.WebSocketException;
import org.eclipse.jetty.websocket.core.internal.messages.MessageSink;
@ -270,6 +271,12 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
@Override
public void onClosed(CloseStatus closeStatus, Callback callback)
{
if (activeMessageSink != null)
{
activeMessageSink.fail(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
activeMessageSink = null;
}
notifyOnClose(closeStatus, callback);
container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionClosed(session));

View File

@ -80,6 +80,12 @@ public abstract class AbstractDecodedMessageSink implements MessageSink
_messageSink.accept(frame, callback);
}
@Override
public void fail(Throwable failure)
{
_messageSink.fail(failure);
}
public abstract static class Basic<T extends Decoder> extends AbstractDecodedMessageSink
{
protected final List<T> _decoders;

View File

@ -297,6 +297,12 @@ public class JettyWebSocketFrameHandler implements FrameHandler
if (delayedCallback != null)
delayedCallback.failed(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
if (activeMessageSink != null)
{
activeMessageSink.fail(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
activeMessageSink = null;
}
notifyOnClose(closeStatus, callback);
container.notifySessionListeners((listener) -> listener.onWebSocketSessionClosed(session));
}

View File

@ -0,0 +1,115 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.tests;
import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.LogarithmicArrayByteBufferPool.LogarithmicRetainablePool;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketAdapter;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ClientDisconnectTest
{
private final CompletableFuture<ServerSocket> _serverSocketFuture = new CompletableFuture<>();
private final Duration _serverIdleTimeout = Duration.ofSeconds(5);
private final int _messageSize = 5 * 1024 * 1024;
private Server _server;
private ServerConnector _connector;
private WebSocketClient _client;
@WebSocket
public class ServerSocket extends EchoSocket
{
@Override
public void onOpen(Session session)
{
_serverSocketFuture.complete(this);
super.onOpen(session);
}
}
@BeforeEach
public void before() throws Exception
{
_client = new WebSocketClient();
_server = new Server();
_connector = new ServerConnector(_server);
_server.addConnector(_connector);
ServletContextHandler contextHandler = new ServletContextHandler();
JettyWebSocketServletContainerInitializer.configure(contextHandler, ((servletContext, container) ->
{
container.addMapping("/", (req, resp) -> new ServerSocket());
container.setIdleTimeout(_serverIdleTimeout);
container.setMaxBinaryMessageSize(_messageSize);
}));
_server.setHandler(contextHandler);
_server.start();
_client.start();
}
@AfterEach
public void after() throws Exception
{
_client.stop();
_server.stop();
}
@Test
public void testBuffersAfterIncompleteMessage() throws Exception
{
URI uri = URI.create("ws://localhost:" + _connector.getLocalPort());
// Open connection to the server.
Session session = _client.connect(new WebSocketAdapter(), uri).get(5, TimeUnit.SECONDS);
ServerSocket serverSocket = _serverSocketFuture.get(5, TimeUnit.SECONDS);
assertNotNull(serverSocket);
// Send partial payload to server then abruptly close the connection.
byte[] bytes = new byte[300_000];
Arrays.fill(bytes, (byte)'x');
session.setMaxBinaryMessageSize(_messageSize);
session.getRemote().sendPartialBytes(BufferUtil.toBuffer(bytes), false);
session.disconnect();
// Wait for the server to close its session.
assertTrue(serverSocket.closeLatch.await(_serverIdleTimeout.toSeconds() + 1, TimeUnit.SECONDS));
// We should have no buffers still used in the pool.
LogarithmicRetainablePool bufferPool = (LogarithmicRetainablePool)_server.getBean(ByteBufferPool.class).asRetainableByteBufferPool();
assertThat(bufferPool.getDirectByteBufferCount() - bufferPool.getAvailableDirectByteBufferCount(), equalTo(0L));
assertThat(bufferPool.getHeapByteBufferCount() - bufferPool.getAvailableHeapByteBufferCount(), equalTo(0L));
}
}