Issue #9682 - fix RetainableByteBuffer release bug in WebSocket

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2023-05-11 10:58:42 +10:00
parent 413a644a44
commit c8b8ef6bd5
3 changed files with 86 additions and 13 deletions

View File

@ -220,6 +220,15 @@ public class WebSocketConnection extends AbstractConnection implements Connectio
if (!coreSession.isClosed()) if (!coreSession.isClosed())
coreSession.onEof(); coreSession.onEof();
flusher.onClose(cause); flusher.onClose(cause);
try (AutoLock l = lock.lock())
{
if (networkBuffer != null)
{
networkBuffer.clear();
releaseNetworkBuffer();
}
}
super.onClose(cause); super.onClose(cause);
} }

View File

@ -81,8 +81,8 @@ public class JettyWebSocketFrameHandler implements FrameHandler
private MessageSink activeMessageSink; private MessageSink activeMessageSink;
private WebSocketSession session; private WebSocketSession session;
private SuspendState state = SuspendState.DEMANDING; private SuspendState state = SuspendState.DEMANDING;
private Runnable delayedOnFrame; private Frame delayedFrame;
private CoreSession coreSession; private Callback delayedCallback;
public JettyWebSocketFrameHandler(WebSocketContainer container, public JettyWebSocketFrameHandler(WebSocketContainer container,
Object endpointInstance, Object endpointInstance,
@ -151,7 +151,6 @@ public class JettyWebSocketFrameHandler implements FrameHandler
try try
{ {
customizer.customize(coreSession); customizer.customize(coreSession);
this.coreSession = coreSession;
session = new WebSocketSession(container, coreSession, this); session = new WebSocketSession(container, coreSession, this);
if (!session.isOpen()) if (!session.isOpen())
throw new IllegalStateException("Session is not open"); throw new IllegalStateException("Session is not open");
@ -199,7 +198,8 @@ public class JettyWebSocketFrameHandler implements FrameHandler
break; break;
case SUSPENDING: case SUSPENDING:
delayedOnFrame = () -> onFrame(frame, callback); delayedFrame = frame;
delayedCallback = callback;
state = SuspendState.SUSPENDED; state = SuspendState.SUSPENDED;
return; return;
@ -283,12 +283,19 @@ public class JettyWebSocketFrameHandler implements FrameHandler
@Override @Override
public void onClosed(CloseStatus closeStatus, Callback callback) public void onClosed(CloseStatus closeStatus, Callback callback)
{ {
Callback suspendedCallback;
try (AutoLock l = lock.lock()) try (AutoLock l = lock.lock())
{ {
// We are now closed and cannot suspend or resume. // We are now closed and cannot suspend or resume.
state = SuspendState.CLOSED; state = SuspendState.CLOSED;
delayedFrame = null;
suspendedCallback = delayedCallback;
delayedCallback = null;
} }
if (suspendedCallback != null)
suspendedCallback.failed(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
notifyOnClose(closeStatus, callback); notifyOnClose(closeStatus, callback);
container.notifySessionListeners((listener) -> listener.onWebSocketSessionClosed(session)); container.notifySessionListeners((listener) -> listener.onWebSocketSessionClosed(session));
} }
@ -447,7 +454,8 @@ public class JettyWebSocketFrameHandler implements FrameHandler
public void resume() public void resume()
{ {
boolean needDemand = false; boolean needDemand = false;
Runnable delayedFrame = null; Frame frame = null;
Callback callback = null;
try (AutoLock l = lock.lock()) try (AutoLock l = lock.lock())
{ {
switch (state) switch (state)
@ -457,13 +465,13 @@ public class JettyWebSocketFrameHandler implements FrameHandler
case SUSPENDED: case SUSPENDED:
needDemand = true; needDemand = true;
delayedFrame = delayedOnFrame; frame = delayedFrame;
delayedOnFrame = null; callback = delayedCallback;
state = SuspendState.DEMANDING; state = SuspendState.DEMANDING;
break; break;
case SUSPENDING: case SUSPENDING:
if (delayedOnFrame != null) if (delayedFrame != null)
throw new IllegalStateException(); throw new IllegalStateException();
state = SuspendState.DEMANDING; state = SuspendState.DEMANDING;
break; break;
@ -475,8 +483,8 @@ public class JettyWebSocketFrameHandler implements FrameHandler
if (needDemand) if (needDemand)
{ {
if (delayedFrame != null) if (frame != null)
delayedFrame.run(); onFrame(frame, callback);
else else
session.getCoreSession().demand(1); session.getCoreSession().demand(1);
} }

View File

@ -15,16 +15,21 @@ package org.eclipse.jetty.websocket.tests;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.time.Duration;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.io.ArrayRetainableByteBufferPool;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.websocket.api.BatchMode;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.SuspendToken; import org.eclipse.jetty.websocket.api.SuspendToken;
import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.exceptions.WebSocketTimeoutException;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.JettyWebSocketServlet; import org.eclipse.jetty.websocket.server.JettyWebSocketServlet;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
@ -34,7 +39,10 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -64,14 +72,15 @@ public class SuspendResumeTest
} }
} }
private Server server = new Server(); private Server server;
private WebSocketClient client = new WebSocketClient(); private WebSocketClient client;
private SuspendSocket serverSocket = new SuspendSocket(); private SuspendSocket serverSocket;
private ServerConnector connector; private ServerConnector connector;
@BeforeEach @BeforeEach
public void start() throws Exception public void start() throws Exception
{ {
server = new Server();
connector = new ServerConnector(server); connector = new ServerConnector(server);
server.addConnector(connector); server.addConnector(connector);
@ -79,10 +88,12 @@ public class SuspendResumeTest
contextHandler.setContextPath("/"); contextHandler.setContextPath("/");
server.setHandler(contextHandler); server.setHandler(contextHandler);
contextHandler.addServlet(new ServletHolder(new UpgradeServlet()), "/suspend"); contextHandler.addServlet(new ServletHolder(new UpgradeServlet()), "/suspend");
serverSocket = new SuspendSocket();
JettyWebSocketServletContainerInitializer.configure(contextHandler, null); JettyWebSocketServletContainerInitializer.configure(contextHandler, null);
server.start(); server.start();
client = new WebSocketClient();
client.start(); client.start();
} }
@ -189,4 +200,49 @@ public class SuspendResumeTest
// suspend after closed throws ISE // suspend after closed throws ISE
assertThrows(IllegalStateException.class, () -> clientSocket.session.suspend()); assertThrows(IllegalStateException.class, () -> clientSocket.session.suspend());
} }
@Test
public void timeoutWhileSuspended() throws Exception
{
URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/suspend");
EventSocket clientSocket = new EventSocket();
Future<Session> connect = client.connect(clientSocket, uri);
connect.get(5, TimeUnit.SECONDS);
assertTrue(serverSocket.openLatch.await(5, TimeUnit.SECONDS));
// Set short idleTimeout on server.
int idleTimeout = 1000;
serverSocket.session.setIdleTimeout(Duration.ofMillis(idleTimeout));
// Suspend on the server.
clientSocket.session.getRemote().sendString("suspend");
assertThat(serverSocket.textMessages.poll(5, TimeUnit.SECONDS), is("suspend"));
// Send two messages, with batching on, so they are read into same network buffer on the server.
// First frame is read and delayed inside the JettyWebSocketFrameHandler suspendState, second frame remains in the network buffer.
clientSocket.session.getRemote().setBatchMode(BatchMode.ON);
clientSocket.session.getRemote().sendString("no demand");
clientSocket.session.getRemote().sendString("this should sit in network buffer");
clientSocket.session.getRemote().flush();
assertNotNull(serverSocket.suspendToken);
// Make sure both sides are closed.
assertTrue(serverSocket.closeLatch.await(5, TimeUnit.SECONDS));
assertTrue(clientSocket.closeLatch.await(5, TimeUnit.SECONDS));
// We received no additional messages.
assertNull(serverSocket.textMessages.poll());
assertNull(serverSocket.binaryMessages.poll());
// Check the idleTimeout occurred.
assertThat(serverSocket.error, instanceOf(WebSocketTimeoutException.class));
assertNull(clientSocket.error);
assertThat(clientSocket.closeCode, equalTo(StatusCode.SHUTDOWN));
assertThat(clientSocket.closeReason, equalTo("Connection Idle Timeout"));
// We should have no used buffers in the pool.
ArrayRetainableByteBufferPool pool = (ArrayRetainableByteBufferPool)connector.getByteBufferPool().asRetainableByteBufferPool();
assertThat(pool.getHeapByteBufferCount(), equalTo(pool.getAvailableHeapByteBufferCount()));
assertThat(pool.getDirectByteBufferCount(), equalTo(pool.getAvailableDirectByteBufferCount()));
}
} }