Improve cleanup of deflater/inflater pools for PerMessageDeflateExtension

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2022-06-07 18:52:48 +10:00
parent c34483e52b
commit 5a24f90064
8 changed files with 285 additions and 10 deletions

View File

@ -16,9 +16,11 @@ package org.eclipse.jetty.util.compression;
import java.io.Closeable;
import org.eclipse.jetty.util.Pool;
import org.eclipse.jetty.util.component.AbstractLifeCycle;
import org.eclipse.jetty.util.annotation.ManagedObject;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
public abstract class CompressionPool<T> extends AbstractLifeCycle
@ManagedObject
public abstract class CompressionPool<T> extends ContainerLifeCycle
{
public static final int DEFAULT_CAPACITY = 1024;
@ -51,6 +53,11 @@ public abstract class CompressionPool<T> extends AbstractLifeCycle
_capacity = capacity;
}
public Pool<Entry> getPool()
{
return _pool;
}
protected abstract T newPooled();
protected abstract void end(T object);
@ -85,7 +92,10 @@ public abstract class CompressionPool<T> extends AbstractLifeCycle
protected void doStart() throws Exception
{
if (_capacity > 0)
{
_pool = new Pool<>(Pool.StrategyType.RANDOM, _capacity, true);
addBean(_pool);
}
super.doStart();
}
@ -95,6 +105,7 @@ public abstract class CompressionPool<T> extends AbstractLifeCycle
if (_pool != null)
{
_pool.close();
removeBean(_pool);
_pool = null;
}
super.doStop();

View File

@ -23,6 +23,13 @@ public interface Extension extends IncomingFrames, OutgoingFrames
void init(ExtensionConfig config, WebSocketComponents components);
/**
* Used to clean up any resources after connection close.
*/
default void close()
{
}
/**
* The active configuration for this extension.
*

View File

@ -60,6 +60,14 @@ public class ExtensionStack implements IncomingFrames, OutgoingFrames, Dumpable
this.behavior = behavior;
}
public void close()
{
for (Extension e : extensions)
{
e.close();
}
}
@ManagedAttribute(name = "Extension List", readonly = true)
public List<Extension> getExtensions()
{

View File

@ -44,7 +44,6 @@ public class FrameFlusher extends IteratingCallback
{
public static final Frame FLUSH_FRAME = new Frame(OpCode.BINARY);
private static final Logger LOG = LoggerFactory.getLogger(FrameFlusher.class);
private static final Throwable CLOSED_CHANNEL = new ClosedChannelException();
private final AutoLock lock = new AutoLock();
private final LongAdder messagesOut = new LongAdder();
@ -185,7 +184,14 @@ public class FrameFlusher extends IteratingCallback
{
try (AutoLock l = lock.lock())
{
closedCause = cause == null ? CLOSED_CHANNEL : cause;
closedCause = cause == null ? new ClosedChannelException()
{
@Override
public Throwable fillInStackTrace()
{
return this;
}
} : cause;
}
iterate();
}

View File

@ -14,6 +14,7 @@
package org.eclipse.jetty.websocket.core.internal;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@ -146,6 +147,17 @@ public class PerMessageDeflateExtension extends AbstractExtension implements Dem
super.init(configNegotiated, components);
}
@Override
public void close()
{
// TODO: use IteratingCallback.close() instead of creating exception with failFlusher methods.
ClosedChannelException exception = new ClosedChannelException();
incomingFlusher.failFlusher(exception);
outgoingFlusher.failFlusher(exception);
releaseInflater();
releaseDeflater();
}
private static String toDetail(Inflater inflater)
{
return String.format("Inflater[finished=%b,read=%d,written=%d,remaining=%d,in=%d,out=%d]", inflater.finished(), inflater.getBytesRead(),

View File

@ -77,6 +77,29 @@ public abstract class TransformingFlusher
notifyCallbackFailure(callback, failure);
}
/**
* Used to fail this flusher possibly from an external event such as a callback.
* @param t the failure.
*/
public void failFlusher(Throwable t)
{
boolean failed = false;
try (AutoLock l = lock.lock())
{
if (failure == null)
{
failure = t;
failed = true;
}
}
if (failed)
{
flusher.failed(t);
flusher.iterate();
}
}
private void onFailure(Throwable t)
{
try (AutoLock l = lock.lock())
@ -103,8 +126,14 @@ public abstract class TransformingFlusher
private FrameEntry current;
@Override
protected Action process()
protected Action process() throws Throwable
{
try (AutoLock l = lock.lock())
{
if (failure != null)
throw failure;
}
if (finished)
{
if (current != null)
@ -134,8 +163,11 @@ public abstract class TransformingFlusher
if (log.isDebugEnabled())
log.debug("onCompleteFailure {}", t.toString());
notifyCallbackFailure(current.callback, t);
current = null;
if (current != null)
{
notifyCallbackFailure(current.callback, t);
current = null;
}
onFailure(t);
}
}

View File

@ -254,12 +254,13 @@ public class WebSocketCoreSession implements IncomingFrames, CoreSession, Dumpab
closeConnection(sessionState.getCloseStatus(), Callback.NOOP);
}
public void closeConnection(CloseStatus closeStatus, Callback callback)
private void closeConnection(CloseStatus closeStatus, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("closeConnection() {} {}", closeStatus, this);
abort();
extensionStack.close();
// Forward Errors to Local WebSocket EndPoint
if (closeStatus.isAbnormal() && closeStatus.getCause() != null)

View File

@ -13,30 +13,47 @@
package org.eclipse.jetty.websocket.tests;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.compression.CompressionPool;
import org.eclipse.jetty.util.compression.DeflaterPool;
import org.eclipse.jetty.util.compression.InflaterPool;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.common.WebSocketSession;
import org.eclipse.jetty.websocket.core.internal.WebSocketCoreSession;
import org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class PermessageDeflateBufferTest
@ -44,6 +61,10 @@ public class PermessageDeflateBufferTest
private Server server;
private ServerConnector connector;
private WebSocketClient client;
private JettyWebSocketServerContainer serverContainer;
private final FailEndPointOutgoing outgoingFailEndPoint = new FailEndPointOutgoing();
private final FailEndPointIncoming incomingFailEndPoint = new FailEndPointIncoming();
private final ServerSocket serverSocket = new ServerSocket();
// @checkstyle-disable-check : AvoidEscapedUnicodeCharactersCheck
private static final List<String> DICT = Arrays.asList(
@ -83,7 +104,7 @@ public class PermessageDeflateBufferTest
public void before() throws Exception
{
server = new Server();
connector = new ServerConnector(server);
connector = new ServerConnector(server, 1, 1);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS);
@ -93,10 +114,13 @@ public class PermessageDeflateBufferTest
{
container.setMaxTextMessageSize(65535);
container.setInputBufferSize(16384);
container.addMapping("/", ServerSocket.class);
container.addMapping("/", (req, resp) -> serverSocket);
container.addMapping("/outgoingFail", (req, resp) -> outgoingFailEndPoint);
container.addMapping("/incomingFail", (req, resp) -> incomingFailEndPoint);
});
server.start();
serverContainer = JettyWebSocketServerContainer.getContainer(contextHandler.getServletContext());
client = new WebSocketClient();
client.start();
}
@ -157,4 +181,178 @@ public class PermessageDeflateBufferTest
assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(socket.closeCode, equalTo(StatusCode.NORMAL));
}
@Test
public void testClientPartialMessageThenServerIdleTimeout() throws Exception
{
Duration idleTimeout = Duration.ofMillis(1000);
serverContainer.setIdleTimeout(idleTimeout);
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
EventSocket socket = new EventSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/incomingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendPartialString("partial", false);
// Wait for the idle timeout to elapse.
assertTrue(incomingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@Test
public void testClientPartialMessageThenClientClose() throws Exception
{
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
PartialTextSocket socket = new PartialTextSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/incomingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendPartialString("partial", false);
// Wait for the server to process the partial message.
assertThat(socket.partialMessages.poll(5, TimeUnit.SECONDS), equalTo("partial" + "last=true"));
// Abruptly close the connection from the client.
((WebSocketCoreSession)((WebSocketSession)session).getCoreSession()).getConnection().getEndPoint().close();
// Wait for the server to process the close.
assertTrue(incomingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@Test
public void testServerPartialMessageThenServerIdleTimeout() throws Exception
{
Duration idleTimeout = Duration.ofMillis(1000);
serverContainer.setIdleTimeout(idleTimeout);
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
EventSocket socket = new EventSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/outgoingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendString("hello");
// Wait for the idle timeout to elapse.
assertTrue(outgoingFailEndPoint.closeLatch.await(2 * idleTimeout.toMillis(), TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@Test
public void testServerPartialMessageThenClientClose() throws Exception
{
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
PartialTextSocket socket = new PartialTextSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/outgoingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendString("hello");
// Wait for the server to process the message.
assertThat(socket.partialMessages.poll(5, TimeUnit.SECONDS), equalTo("hello" + "last=false"));
// Abruptly close the connection from the client.
((WebSocketCoreSession)((WebSocketSession)session).getCoreSession()).getConnection().getEndPoint().close();
// Wait for the server to process the close.
assertTrue(outgoingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@WebSocket
public static class PartialTextSocket
{
private static final Logger LOG = LoggerFactory.getLogger(EventSocket.class);
public Session session;
public BlockingQueue<String> partialMessages = new BlockingArrayQueue<>();
public CountDownLatch openLatch = new CountDownLatch(1);
public CountDownLatch closeLatch = new CountDownLatch(1);
@OnWebSocketConnect
public void onOpen(Session session)
{
this.session = session;
openLatch.countDown();
}
@OnWebSocketMessage
public void onMessage(String message, boolean last) throws IOException
{
partialMessages.offer(message + "last=" + last);
}
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
closeLatch.countDown();
}
}
@WebSocket
public static class FailEndPointOutgoing
{
public CountDownLatch closeLatch = new CountDownLatch(1);
@OnWebSocketMessage
public void onMessage(Session session, String message) throws IOException
{
session.getRemote().sendPartialString(message, false);
}
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
closeLatch.countDown();
}
}
@WebSocket
public static class FailEndPointIncoming
{
public CountDownLatch closeLatch = new CountDownLatch(1);
@OnWebSocketMessage
public void onMessage(Session session, String message, boolean last) throws IOException
{
session.getRemote().sendString(message);
}
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
closeLatch.countDown();
}
}
}