Merge pull request #8134 from eclipse/jetty-10.0.x-websocketPermessageDeflatePools

Improve cleanup of deflater/inflater pools for PerMessageDeflateExtension
This commit is contained in:
Lachlan 2022-06-10 09:43:23 +10:00 committed by GitHub
commit b1c19c0b0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 309 additions and 11 deletions

View File

@ -16,9 +16,11 @@ package org.eclipse.jetty.util.compression;
import java.io.Closeable; import java.io.Closeable;
import org.eclipse.jetty.util.Pool; import org.eclipse.jetty.util.Pool;
import org.eclipse.jetty.util.component.AbstractLifeCycle; import org.eclipse.jetty.util.annotation.ManagedObject;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
public abstract class CompressionPool<T> extends AbstractLifeCycle @ManagedObject
public abstract class CompressionPool<T> extends ContainerLifeCycle
{ {
public static final int DEFAULT_CAPACITY = 1024; public static final int DEFAULT_CAPACITY = 1024;
@ -51,6 +53,11 @@ public abstract class CompressionPool<T> extends AbstractLifeCycle
_capacity = capacity; _capacity = capacity;
} }
public Pool<Entry> getPool()
{
return _pool;
}
protected abstract T newPooled(); protected abstract T newPooled();
protected abstract void end(T object); protected abstract void end(T object);
@ -85,7 +92,10 @@ public abstract class CompressionPool<T> extends AbstractLifeCycle
protected void doStart() throws Exception protected void doStart() throws Exception
{ {
if (_capacity > 0) if (_capacity > 0)
{
_pool = new Pool<>(Pool.StrategyType.RANDOM, _capacity, true); _pool = new Pool<>(Pool.StrategyType.RANDOM, _capacity, true);
addBean(_pool);
}
super.doStart(); super.doStart();
} }
@ -95,6 +105,7 @@ public abstract class CompressionPool<T> extends AbstractLifeCycle
if (_pool != null) if (_pool != null)
{ {
_pool.close(); _pool.close();
removeBean(_pool);
_pool = null; _pool = null;
} }
super.doStop(); super.doStop();

View File

@ -13,16 +13,25 @@
package org.eclipse.jetty.websocket.core; package org.eclipse.jetty.websocket.core;
import java.io.Closeable;
/** /**
* Interface for WebSocket Extensions. * Interface for WebSocket Extensions.
* <p> * <p>
* That {@link Frame}s are passed through the Extension via the {@link IncomingFrames} and {@link OutgoingFrames} interfaces * That {@link Frame}s are passed through the Extension via the {@link IncomingFrames} and {@link OutgoingFrames} interfaces
*/ */
public interface Extension extends IncomingFrames, OutgoingFrames public interface Extension extends IncomingFrames, OutgoingFrames, Closeable
{ {
void init(ExtensionConfig config, WebSocketComponents components); void init(ExtensionConfig config, WebSocketComponents components);
/**
* Used to clean up any resources after connection close.
*/
default void close()
{
}
/** /**
* The active configuration for this extension. * The active configuration for this extension.
* *

View File

@ -60,6 +60,22 @@ public class ExtensionStack implements IncomingFrames, OutgoingFrames, Dumpable
this.behavior = behavior; this.behavior = behavior;
} }
public void close()
{
for (Extension ext : extensions)
{
try
{
ext.close();
}
catch (Throwable t)
{
if (LOG.isDebugEnabled())
LOG.debug("Extension Error During Close", t);
}
}
}
@ManagedAttribute(name = "Extension List", readonly = true) @ManagedAttribute(name = "Extension List", readonly = true)
public List<Extension> getExtensions() public List<Extension> getExtensions()
{ {

View File

@ -44,7 +44,6 @@ public class FrameFlusher extends IteratingCallback
{ {
public static final Frame FLUSH_FRAME = new Frame(OpCode.BINARY); public static final Frame FLUSH_FRAME = new Frame(OpCode.BINARY);
private static final Logger LOG = LoggerFactory.getLogger(FrameFlusher.class); private static final Logger LOG = LoggerFactory.getLogger(FrameFlusher.class);
private static final Throwable CLOSED_CHANNEL = new ClosedChannelException();
private final AutoLock lock = new AutoLock(); private final AutoLock lock = new AutoLock();
private final LongAdder messagesOut = new LongAdder(); private final LongAdder messagesOut = new LongAdder();
@ -185,7 +184,15 @@ public class FrameFlusher extends IteratingCallback
{ {
try (AutoLock l = lock.lock()) try (AutoLock l = lock.lock())
{ {
closedCause = cause == null ? CLOSED_CHANNEL : cause; // TODO: find a way to not create exception if cause is null.
closedCause = cause == null ? new ClosedChannelException()
{
@Override
public Throwable fillInStackTrace()
{
return this;
}
} : cause;
} }
iterate(); iterate();
} }

View File

@ -14,6 +14,7 @@
package org.eclipse.jetty.websocket.core.internal; package org.eclipse.jetty.websocket.core.internal;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -146,6 +147,24 @@ public class PerMessageDeflateExtension extends AbstractExtension implements Dem
super.init(configNegotiated, components); super.init(configNegotiated, components);
} }
@Override
public void close()
{
// TODO: use IteratingCallback.close() instead of creating exception with failFlusher methods.
ClosedChannelException exception = new ClosedChannelException()
{
@Override
public Throwable fillInStackTrace()
{
return this;
}
};
incomingFlusher.failFlusher(exception);
outgoingFlusher.failFlusher(exception);
releaseInflater();
releaseDeflater();
}
private static String toDetail(Inflater inflater) private static String toDetail(Inflater inflater)
{ {
return String.format("Inflater[finished=%b,read=%d,written=%d,remaining=%d,in=%d,out=%d]", inflater.finished(), inflater.getBytesRead(), return String.format("Inflater[finished=%b,read=%d,written=%d,remaining=%d,in=%d,out=%d]", inflater.finished(), inflater.getBytesRead(),

View File

@ -77,6 +77,34 @@ public abstract class TransformingFlusher
notifyCallbackFailure(callback, failure); notifyCallbackFailure(callback, failure);
} }
/**
* Used to fail this flusher possibly from an external event such as a callback.
* @param t the failure.
*/
public void failFlusher(Throwable t)
{
// TODO: find a way to close the flusher in non error case without exception.
boolean failed = false;
try (AutoLock l = lock.lock())
{
if (failure == null)
{
failure = t;
failed = true;
}
else
{
failure.addSuppressed(t);
}
}
if (failed)
{
flusher.failed(t);
flusher.iterate();
}
}
private void onFailure(Throwable t) private void onFailure(Throwable t)
{ {
try (AutoLock l = lock.lock()) try (AutoLock l = lock.lock())
@ -103,8 +131,14 @@ public abstract class TransformingFlusher
private FrameEntry current; private FrameEntry current;
@Override @Override
protected Action process() protected Action process() throws Throwable
{ {
try (AutoLock l = lock.lock())
{
if (failure != null)
throw failure;
}
if (finished) if (finished)
{ {
if (current != null) if (current != null)
@ -134,8 +168,11 @@ public abstract class TransformingFlusher
if (log.isDebugEnabled()) if (log.isDebugEnabled())
log.debug("onCompleteFailure {}", t.toString()); log.debug("onCompleteFailure {}", t.toString());
notifyCallbackFailure(current.callback, t); if (current != null)
current = null; {
notifyCallbackFailure(current.callback, t);
current = null;
}
onFailure(t); onFailure(t);
} }
} }

View File

@ -254,12 +254,13 @@ public class WebSocketCoreSession implements IncomingFrames, CoreSession, Dumpab
closeConnection(sessionState.getCloseStatus(), Callback.NOOP); closeConnection(sessionState.getCloseStatus(), Callback.NOOP);
} }
public void closeConnection(CloseStatus closeStatus, Callback callback) private void closeConnection(CloseStatus closeStatus, Callback callback)
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("closeConnection() {} {}", closeStatus, this); LOG.debug("closeConnection() {} {}", closeStatus, this);
abort(); abort();
extensionStack.close();
// Forward Errors to Local WebSocket EndPoint // Forward Errors to Local WebSocket EndPoint
if (closeStatus.isAbnormal() && closeStatus.getCause() != null) if (closeStatus.isAbnormal() && closeStatus.getCause() != null)

View File

@ -13,30 +13,47 @@
package org.eclipse.jetty.websocket.tests; package org.eclipse.jetty.websocket.tests;
import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.compression.CompressionPool;
import org.eclipse.jetty.util.compression.DeflaterPool;
import org.eclipse.jetty.util.compression.InflaterPool;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode; import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.common.WebSocketSession;
import org.eclipse.jetty.websocket.core.internal.WebSocketCoreSession;
import org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class PermessageDeflateBufferTest public class PermessageDeflateBufferTest
@ -44,6 +61,10 @@ public class PermessageDeflateBufferTest
private Server server; private Server server;
private ServerConnector connector; private ServerConnector connector;
private WebSocketClient client; private WebSocketClient client;
private JettyWebSocketServerContainer serverContainer;
private final FailEndPointOutgoing outgoingFailEndPoint = new FailEndPointOutgoing();
private final FailEndPointIncoming incomingFailEndPoint = new FailEndPointIncoming();
private final ServerSocket serverSocket = new ServerSocket();
// @checkstyle-disable-check : AvoidEscapedUnicodeCharactersCheck // @checkstyle-disable-check : AvoidEscapedUnicodeCharactersCheck
private static final List<String> DICT = Arrays.asList( private static final List<String> DICT = Arrays.asList(
@ -83,7 +104,7 @@ public class PermessageDeflateBufferTest
public void before() throws Exception public void before() throws Exception
{ {
server = new Server(); server = new Server();
connector = new ServerConnector(server); connector = new ServerConnector(server, 1, 1);
server.addConnector(connector); server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS); ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS);
@ -93,10 +114,13 @@ public class PermessageDeflateBufferTest
{ {
container.setMaxTextMessageSize(65535); container.setMaxTextMessageSize(65535);
container.setInputBufferSize(16384); container.setInputBufferSize(16384);
container.addMapping("/", ServerSocket.class); container.addMapping("/", (req, resp) -> serverSocket);
container.addMapping("/outgoingFail", (req, resp) -> outgoingFailEndPoint);
container.addMapping("/incomingFail", (req, resp) -> incomingFailEndPoint);
}); });
server.start(); server.start();
serverContainer = JettyWebSocketServerContainer.getContainer(contextHandler.getServletContext());
client = new WebSocketClient(); client = new WebSocketClient();
client.start(); client.start();
} }
@ -157,4 +181,178 @@ public class PermessageDeflateBufferTest
assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS)); assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(socket.closeCode, equalTo(StatusCode.NORMAL)); assertThat(socket.closeCode, equalTo(StatusCode.NORMAL));
} }
@Test
public void testClientPartialMessageThenServerIdleTimeout() throws Exception
{
Duration idleTimeout = Duration.ofMillis(1000);
serverContainer.setIdleTimeout(idleTimeout);
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
EventSocket socket = new EventSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/incomingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendPartialString("partial", false);
// Wait for the idle timeout to elapse.
assertTrue(incomingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@Test
public void testClientPartialMessageThenClientClose() throws Exception
{
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
PartialTextSocket socket = new PartialTextSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/incomingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendPartialString("partial", false);
// Wait for the server to process the partial message.
assertThat(socket.partialMessages.poll(5, TimeUnit.SECONDS), equalTo("partial" + "last=true"));
// Abruptly close the connection from the client.
((WebSocketCoreSession)((WebSocketSession)session).getCoreSession()).getConnection().getEndPoint().close();
// Wait for the server to process the close.
assertTrue(incomingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@Test
public void testServerPartialMessageThenServerIdleTimeout() throws Exception
{
Duration idleTimeout = Duration.ofMillis(1000);
serverContainer.setIdleTimeout(idleTimeout);
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
EventSocket socket = new EventSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/outgoingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendString("hello");
// Wait for the idle timeout to elapse.
assertTrue(outgoingFailEndPoint.closeLatch.await(2 * idleTimeout.toMillis(), TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@Test
public void testServerPartialMessageThenClientClose() throws Exception
{
ClientUpgradeRequest clientUpgradeRequest = new ClientUpgradeRequest();
clientUpgradeRequest.addExtensions("permessage-deflate");
PartialTextSocket socket = new PartialTextSocket();
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/outgoingFail");
Session session = client.connect(socket, uri, clientUpgradeRequest).get(5, TimeUnit.SECONDS);
session.getRemote().sendString("hello");
// Wait for the server to process the message.
assertThat(socket.partialMessages.poll(5, TimeUnit.SECONDS), equalTo("hello" + "last=false"));
// Abruptly close the connection from the client.
((WebSocketCoreSession)((WebSocketSession)session).getCoreSession()).getConnection().getEndPoint().close();
// Wait for the server to process the close.
assertTrue(outgoingFailEndPoint.closeLatch.await(5, TimeUnit.SECONDS));
server.getContainedBeans(InflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased inflater pool entries: " + pool.dump()));
server.getContainedBeans(DeflaterPool.class).stream()
.map(CompressionPool::getPool)
.forEach(pool -> assertEquals(0, pool.getInUseCount(), "unreleased deflater pool entries: " + pool.dump()));
}
@WebSocket
public static class PartialTextSocket
{
private static final Logger LOG = LoggerFactory.getLogger(EventSocket.class);
public Session session;
public BlockingQueue<String> partialMessages = new BlockingArrayQueue<>();
public CountDownLatch openLatch = new CountDownLatch(1);
public CountDownLatch closeLatch = new CountDownLatch(1);
@OnWebSocketConnect
public void onOpen(Session session)
{
this.session = session;
openLatch.countDown();
}
@OnWebSocketMessage
public void onMessage(String message, boolean last) throws IOException
{
partialMessages.offer(message + "last=" + last);
}
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
closeLatch.countDown();
}
}
@WebSocket
public static class FailEndPointOutgoing
{
public CountDownLatch closeLatch = new CountDownLatch(1);
@OnWebSocketMessage
public void onMessage(Session session, String message) throws IOException
{
session.getRemote().sendPartialString(message, false);
}
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
closeLatch.countDown();
}
}
@WebSocket
public static class FailEndPointIncoming
{
public CountDownLatch closeLatch = new CountDownLatch(1);
@OnWebSocketMessage
public void onMessage(Session session, String message, boolean last) throws IOException
{
session.getRemote().sendString(message);
}
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
closeLatch.countDown();
}
}
} }