diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java index 2f8f54eb6d6..14cb66a15e3 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/PerMessageDeflateExtension.java @@ -470,14 +470,25 @@ public class PerMessageDeflateExtension extends AbstractExtension implements Dem chunk.setPayload(byteBuffer); chunk.setFin(frame.isFin() && complete); - // Capture the current AtomicReference. + // If we are complete we return true, then DemandingFlusher.process() will null out the Frame and Callback. + // The application may decide to hold onto the buffer and delay completing the callback, so we need to capture + // references to these in the payloadCallback and not rely on state of the flusher which may have moved on. + // This flusher could be failed while the application already has the payloadCallback, so we need protection against + // the flusher failing and the application completing the callback, that's why we use the payload AtomicReference. + boolean completeCallback = complete; AtomicReference payloadRef = _payloadRef; - boolean succeedCallback = complete; - Callback payloadCallback = Callback.from(() -> releasePayload(payloadRef), Callback.from(() -> + Callback payloadCallback = Callback.from(() -> { - if (succeedCallback) + releasePayload(payloadRef); + if (completeCallback) callback.succeeded(); - }, this::failFlusher)); + }, t -> + { + releasePayload(payloadRef); + if (completeCallback) + callback.failed(t); + failFlusher(t); + }); emitFrame(chunk, payloadCallback); if (LOG.isDebugEnabled()) diff --git a/jetty-ee9/jetty-ee9-websocket/jetty-ee9-websocket-jetty-tests/src/test/java/org/eclipse/jetty/ee9/websocket/tests/LargeDeflateTest.java b/jetty-ee9/jetty-ee9-websocket/jetty-ee9-websocket-jetty-tests/src/test/java/org/eclipse/jetty/ee9/websocket/tests/LargeDeflateTest.java index d3d36b71ebf..1bb9a218219 100644 --- a/jetty-ee9/jetty-ee9-websocket/jetty-ee9-websocket-jetty-tests/src/test/java/org/eclipse/jetty/ee9/websocket/tests/LargeDeflateTest.java +++ b/jetty-ee9/jetty-ee9-websocket/jetty-ee9-websocket-jetty-tests/src/test/java/org/eclipse/jetty/ee9/websocket/tests/LargeDeflateTest.java @@ -33,6 +33,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -92,6 +93,27 @@ public class LargeDeflateTest assertThat(message, is(sentMessage)); } + @Test + void testDeflateLargerThanMaxMessage() throws Exception + { + ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(); + upgradeRequest.addExtensions("permessage-deflate"); + + EventSocket clientSocket = new EventSocket(); + ByteBuffer message = largePayloads(); + Session session = _client.connect(clientSocket, URI.create("ws://localhost:" + _connector.getLocalPort() + "/ws"), upgradeRequest).get(); + + // Set the maxBinaryMessageSize on the server to be lower than the size of the message. + assertTrue(_serverSocket.openLatch.await(5, TimeUnit.SECONDS)); + _serverSocket.session.setMaxBinaryMessageSize(message.remaining() - 1024); + + session.getRemote().sendBytes(message); + assertTrue(clientSocket.closeLatch.await(5, TimeUnit.SECONDS)); + assertTrue(_serverSocket.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(_serverSocket.closeCode, is(StatusCode.MESSAGE_TOO_LARGE)); + assertThat(_serverSocket.closeReason, containsString("Binary message too large")); + } + private static ByteBuffer largePayloads() { var bytes = new byte[4 * 1024 * 1024];