Issue #9682 - notify WebSocket message sinks of connection close
Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
parent
a471c08717
commit
a0fe47e1e8
|
@ -89,11 +89,14 @@ public class ByteBufferCallbackAccumulator
|
|||
|
||||
public void fail(Throwable t)
|
||||
{
|
||||
for (Entry entry : _entries)
|
||||
// In some usages the callback recursively fails the accumulator.
|
||||
// So we copy and clear to avoid double completing the callback.
|
||||
ArrayList<Entry> entries = new ArrayList<>(_entries);
|
||||
_entries.clear();
|
||||
_length = 0;
|
||||
for (Entry entry : entries)
|
||||
{
|
||||
entry.callback.failed(t);
|
||||
}
|
||||
_entries.clear();
|
||||
_length = 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,4 +28,9 @@ public abstract class AbstractMessageSink implements MessageSink
|
|||
this.session = Objects.requireNonNull(session, "CoreSession");
|
||||
this.methodHandle = Objects.requireNonNull(methodHandle, "MethodHandle");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
|
@ -106,4 +106,11 @@ public class ByteArrayMessageSink extends AbstractMessageSink
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
if (out != null)
|
||||
out.fail(failure);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,4 +113,11 @@ public class ByteBufferMessageSink extends AbstractMessageSink
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
if (out != null)
|
||||
out.fail(failure);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -164,4 +164,11 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
|
|||
|
||||
typeSink.accept(frame, frameCallback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
if (typeSink != null)
|
||||
typeSink.fail(failure);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,40 +127,6 @@ public class MessageInputStream extends InputStream implements MessageSink
|
|||
return fillLen;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException
|
||||
{
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("close()");
|
||||
|
||||
ArrayList<Entry> entries = new ArrayList<>();
|
||||
try (AutoLock l = lock.lock())
|
||||
{
|
||||
if (closed)
|
||||
return;
|
||||
closed = true;
|
||||
|
||||
if (currentEntry != null)
|
||||
{
|
||||
entries.add(currentEntry);
|
||||
currentEntry = null;
|
||||
}
|
||||
|
||||
// Clear queue and fail all entries.
|
||||
entries.addAll(buffers);
|
||||
buffers.clear();
|
||||
buffers.offer(CLOSED);
|
||||
}
|
||||
|
||||
// Succeed all entries as we don't need them anymore (failing would close the connection).
|
||||
for (Entry e : entries)
|
||||
{
|
||||
e.callback.succeeded();
|
||||
}
|
||||
|
||||
super.close();
|
||||
}
|
||||
|
||||
public void setTimeout(long timeoutMs)
|
||||
{
|
||||
this.timeoutMs = timeoutMs;
|
||||
|
@ -218,6 +184,56 @@ public class MessageInputStream extends InputStream implements MessageSink
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException
|
||||
{
|
||||
fail(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("close()");
|
||||
|
||||
ArrayList<Entry> entries = new ArrayList<>();
|
||||
try (AutoLock l = lock.lock())
|
||||
{
|
||||
if (closed)
|
||||
return;
|
||||
closed = true;
|
||||
|
||||
if (currentEntry != null)
|
||||
{
|
||||
entries.add(currentEntry);
|
||||
currentEntry = null;
|
||||
}
|
||||
|
||||
// Clear queue and fail all entries.
|
||||
entries.addAll(buffers);
|
||||
buffers.clear();
|
||||
buffers.offer(CLOSED);
|
||||
}
|
||||
|
||||
// Succeed all entries as we don't need them anymore (failing would close the connection).
|
||||
for (Entry e : entries)
|
||||
{
|
||||
if (failure == null)
|
||||
e.callback.succeeded();
|
||||
else
|
||||
e.callback.failed(failure);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
super.close();
|
||||
}
|
||||
catch (IOException e)
|
||||
{
|
||||
LOG.debug("Failure Closing InputStream", e);
|
||||
}
|
||||
}
|
||||
|
||||
private static class Entry
|
||||
{
|
||||
public ByteBuffer buffer;
|
||||
|
|
|
@ -87,6 +87,12 @@ public class MessageReader extends Reader implements MessageSink
|
|||
stream.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
stream.fail(failure);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void accept(Frame frame, Callback callback)
|
||||
{
|
||||
|
|
|
@ -29,4 +29,12 @@ public interface MessageSink
|
|||
* @param callback the callback for how the frame was consumed
|
||||
*/
|
||||
void accept(Frame frame, Callback callback);
|
||||
|
||||
/**
|
||||
* <p>Fail the message sink.</p>
|
||||
* <p>Release any resources and fail all stored callbacks as {@link #accept(Frame, Callback)} will never be called again.</p>
|
||||
*
|
||||
* @param failure the failure that occurred.
|
||||
*/
|
||||
void fail(Throwable failure);
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.eclipse.jetty.websocket.core.Frame;
|
|||
|
||||
public class PartialByteArrayMessageSink extends AbstractMessageSink
|
||||
{
|
||||
private static byte[] EMPTY_BUFFER = new byte[0];
|
||||
private static final byte[] EMPTY_BUFFER = new byte[0];
|
||||
|
||||
public PartialByteArrayMessageSink(CoreSession session, MethodHandle methodHandle)
|
||||
{
|
||||
|
|
|
@ -69,4 +69,11 @@ public class StringMessageSink extends AbstractMessageSink
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
if (out != null)
|
||||
out.reset();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.eclipse.jetty.websocket.core.CoreSession;
|
|||
import org.eclipse.jetty.websocket.core.Frame;
|
||||
import org.eclipse.jetty.websocket.core.FrameHandler;
|
||||
import org.eclipse.jetty.websocket.core.OpCode;
|
||||
import org.eclipse.jetty.websocket.core.exception.CloseException;
|
||||
import org.eclipse.jetty.websocket.core.exception.ProtocolException;
|
||||
import org.eclipse.jetty.websocket.core.exception.WebSocketException;
|
||||
import org.eclipse.jetty.websocket.core.internal.messages.MessageSink;
|
||||
|
@ -270,6 +271,12 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
|
|||
@Override
|
||||
public void onClosed(CloseStatus closeStatus, Callback callback)
|
||||
{
|
||||
if (activeMessageSink != null)
|
||||
{
|
||||
activeMessageSink.fail(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
|
||||
activeMessageSink = null;
|
||||
}
|
||||
|
||||
notifyOnClose(closeStatus, callback);
|
||||
container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionClosed(session));
|
||||
|
||||
|
|
|
@ -80,6 +80,12 @@ public abstract class AbstractDecodedMessageSink implements MessageSink
|
|||
_messageSink.accept(frame, callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fail(Throwable failure)
|
||||
{
|
||||
_messageSink.fail(failure);
|
||||
}
|
||||
|
||||
public abstract static class Basic<T extends Decoder> extends AbstractDecodedMessageSink
|
||||
{
|
||||
protected final List<T> _decoders;
|
||||
|
|
|
@ -297,6 +297,12 @@ public class JettyWebSocketFrameHandler implements FrameHandler
|
|||
if (delayedCallback != null)
|
||||
delayedCallback.failed(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
|
||||
|
||||
if (activeMessageSink != null)
|
||||
{
|
||||
activeMessageSink.fail(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
|
||||
activeMessageSink = null;
|
||||
}
|
||||
|
||||
notifyOnClose(closeStatus, callback);
|
||||
container.notifySessionListeners((listener) -> listener.onWebSocketSessionClosed(session));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
//
|
||||
// ========================================================================
|
||||
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
|
||||
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
|
||||
// ========================================================================
|
||||
//
|
||||
|
||||
package org.eclipse.jetty.websocket.tests;
|
||||
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import org.eclipse.jetty.io.ByteBufferPool;
|
||||
import org.eclipse.jetty.io.LogarithmicArrayByteBufferPool.LogarithmicRetainablePool;
|
||||
import org.eclipse.jetty.server.Server;
|
||||
import org.eclipse.jetty.server.ServerConnector;
|
||||
import org.eclipse.jetty.servlet.ServletContextHandler;
|
||||
import org.eclipse.jetty.util.BufferUtil;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.api.WebSocketAdapter;
|
||||
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class ClientDisconnectTest
|
||||
{
|
||||
private final CompletableFuture<ServerSocket> _serverSocketFuture = new CompletableFuture<>();
|
||||
private final Duration _serverIdleTimeout = Duration.ofSeconds(5);
|
||||
private final int _messageSize = 5 * 1024 * 1024;
|
||||
private Server _server;
|
||||
private ServerConnector _connector;
|
||||
private WebSocketClient _client;
|
||||
|
||||
@WebSocket
|
||||
public class ServerSocket extends EchoSocket
|
||||
{
|
||||
@Override
|
||||
public void onOpen(Session session)
|
||||
{
|
||||
_serverSocketFuture.complete(this);
|
||||
super.onOpen(session);
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
public void before() throws Exception
|
||||
{
|
||||
_client = new WebSocketClient();
|
||||
_server = new Server();
|
||||
_connector = new ServerConnector(_server);
|
||||
_server.addConnector(_connector);
|
||||
|
||||
ServletContextHandler contextHandler = new ServletContextHandler();
|
||||
JettyWebSocketServletContainerInitializer.configure(contextHandler, ((servletContext, container) ->
|
||||
{
|
||||
container.addMapping("/", (req, resp) -> new ServerSocket());
|
||||
container.setIdleTimeout(_serverIdleTimeout);
|
||||
container.setMaxBinaryMessageSize(_messageSize);
|
||||
}));
|
||||
_server.setHandler(contextHandler);
|
||||
|
||||
_server.start();
|
||||
_client.start();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void after() throws Exception
|
||||
{
|
||||
_client.stop();
|
||||
_server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBuffersAfterIncompleteMessage() throws Exception
|
||||
{
|
||||
URI uri = URI.create("ws://localhost:" + _connector.getLocalPort());
|
||||
|
||||
// Open connection to the server.
|
||||
Session session = _client.connect(new WebSocketAdapter(), uri).get(5, TimeUnit.SECONDS);
|
||||
ServerSocket serverSocket = _serverSocketFuture.get(5, TimeUnit.SECONDS);
|
||||
assertNotNull(serverSocket);
|
||||
|
||||
// Send partial payload to server then abruptly close the connection.
|
||||
byte[] bytes = new byte[300_000];
|
||||
Arrays.fill(bytes, (byte)'x');
|
||||
session.setMaxBinaryMessageSize(_messageSize);
|
||||
session.getRemote().sendPartialBytes(BufferUtil.toBuffer(bytes), false);
|
||||
session.disconnect();
|
||||
|
||||
// Wait for the server to close its session.
|
||||
assertTrue(serverSocket.closeLatch.await(_serverIdleTimeout.toSeconds() + 1, TimeUnit.SECONDS));
|
||||
|
||||
// We should have no buffers still used in the pool.
|
||||
LogarithmicRetainablePool bufferPool = (LogarithmicRetainablePool)_server.getBean(ByteBufferPool.class).asRetainableByteBufferPool();
|
||||
assertThat(bufferPool.getDirectByteBufferCount() - bufferPool.getAvailableDirectByteBufferCount(), equalTo(0L));
|
||||
assertThat(bufferPool.getHeapByteBufferCount() - bufferPool.getAvailableHeapByteBufferCount(), equalTo(0L));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue