diff --git a/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/PartialListenerTest.java b/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/PartialListenerTest.java new file mode 100644 index 00000000000..0a895a8c456 --- /dev/null +++ b/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/server/PartialListenerTest.java @@ -0,0 +1,301 @@ +// +// ======================================================================== +// Copyright (c) 1995-2019 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.eclipse.jetty.websocket.tests.server; + +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; + +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.handler.DefaultHandler; +import org.eclipse.jetty.server.handler.HandlerList; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.log.StacklessLogging; +import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.WebSocketPartialListener; +import org.eclipse.jetty.websocket.api.util.WSURI; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.common.WebSocketSession; +import org.eclipse.jetty.websocket.common.util.TextUtil; +import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; +import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; +import org.eclipse.jetty.websocket.servlet.WebSocketCreator; +import org.eclipse.jetty.websocket.servlet.WebSocketServlet; +import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; +import org.eclipse.jetty.websocket.tests.CloseTrackingEndpoint; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class PartialListenerTest +{ + private Server server; + private PartialCreator partialCreator; + private WebSocketClient client; + + @BeforeEach + public void startServer() throws Exception + { + server = new Server(); + + ServerConnector connector = new ServerConnector(server); + server.addConnector(connector); + + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + + ServletHolder closeEndpoint = new ServletHolder(new WebSocketServlet() + { + @Override + public void configure(WebSocketServletFactory factory) + { + factory.getPolicy().setIdleTimeout(SECONDS.toMillis(2)); + partialCreator = new PartialCreator(); + factory.setCreator(partialCreator); + } + }); + context.addServlet(closeEndpoint, "/ws"); + + HandlerList handlers = new HandlerList(); + handlers.addHandler(context); + handlers.addHandler(new DefaultHandler()); + + server.setHandler(handlers); + + server.start(); + } + + @AfterEach + public void stopServer() throws Exception + { + server.stop(); + } + + @BeforeEach + public void startClient() throws Exception + { + client = new WebSocketClient(); + client.start(); + } + + @AfterEach + public void stopClient() throws Exception + { + client.stop(); + } + + private void close(Session session) + { + if (session != null) + { + session.close(); + } + } + + @Test + public void testPartialText() throws Exception + { + ClientUpgradeRequest request = new ClientUpgradeRequest(); + CloseTrackingEndpoint clientEndpoint = new CloseTrackingEndpoint(); + + URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws")); + Future futSession = client.connect(clientEndpoint, wsUri, request); + + Session session = null; + try (StacklessLogging ignore = new StacklessLogging(WebSocketSession.class)) + { + session = futSession.get(5, SECONDS); + + RemoteEndpoint clientRemote = session.getRemote(); + clientRemote.sendPartialString("hello", false); + clientRemote.sendPartialString(" ", false); + clientRemote.sendPartialString("world", true); + + PartialEndpoint serverEndpoint = partialCreator.partialEndpoint; + + String event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload=hello, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload= , fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload=world, fin=true]")); + } + finally + { + close(session); + } + } + + @Test + public void testPartialBinary() throws Exception + { + ClientUpgradeRequest request = new ClientUpgradeRequest(); + CloseTrackingEndpoint clientEndpoint = new CloseTrackingEndpoint(); + + URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws")); + Future futSession = client.connect(clientEndpoint, wsUri, request); + + Session session = null; + try (StacklessLogging ignore = new StacklessLogging(WebSocketSession.class)) + { + session = futSession.get(5, SECONDS); + + RemoteEndpoint clientRemote = session.getRemote(); + clientRemote.sendPartialBytes(BufferUtil.toBuffer("hello"), false); + clientRemote.sendPartialBytes(BufferUtil.toBuffer(" "), false); + clientRemote.sendPartialBytes(BufferUtil.toBuffer("world"), true); + + PartialEndpoint serverEndpoint = partialCreator.partialEndpoint; + + String event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("BINARY[payload=<<>>, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("BINARY[payload=<<< >>>, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("BINARY[payload=<<>>, fin=true]")); + } + finally + { + close(session); + } + } + + /** + * Test to ensure that the internal state tracking the partial messages is reset after each complete message. + */ + @Test + public void testPartial_TextBinaryText() throws Exception + { + ClientUpgradeRequest request = new ClientUpgradeRequest(); + CloseTrackingEndpoint clientEndpoint = new CloseTrackingEndpoint(); + + URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws")); + Future futSession = client.connect(clientEndpoint, wsUri, request); + + Session session = null; + try (StacklessLogging ignore = new StacklessLogging(WebSocketSession.class)) + { + session = futSession.get(5, SECONDS); + + RemoteEndpoint clientRemote = session.getRemote(); + clientRemote.sendPartialString("hello", false); + clientRemote.sendPartialString(" ", false); + clientRemote.sendPartialString("world", true); + + clientRemote.sendPartialBytes(BufferUtil.toBuffer("greetings"), false); + clientRemote.sendPartialBytes(BufferUtil.toBuffer(" "), false); + clientRemote.sendPartialBytes(BufferUtil.toBuffer("mars"), true); + + clientRemote.sendPartialString("salutations", false); + clientRemote.sendPartialString(" ", false); + clientRemote.sendPartialString("phobos", true); + + PartialEndpoint serverEndpoint = partialCreator.partialEndpoint; + + String event; + + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload=hello, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload= , fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload=world, fin=true]")); + + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("BINARY[payload=<<>>, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("BINARY[payload=<<< >>>, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("BINARY[payload=<<>>, fin=true]")); + + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload=salutations, fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload= , fin=false]")); + event = serverEndpoint.partialEvents.poll(5, SECONDS); + assertThat("Event", event, is("TEXT[payload=phobos, fin=true]")); + } + finally + { + close(session); + } + } + + + public static class PartialCreator implements WebSocketCreator + { + public PartialEndpoint partialEndpoint; + + @Override + public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) + { + partialEndpoint = new PartialEndpoint(); + return partialEndpoint; + } + } + + public static class PartialEndpoint implements WebSocketPartialListener + { + public Session session; + public CountDownLatch closeLatch = new CountDownLatch(1); + public LinkedBlockingQueue partialEvents = new LinkedBlockingQueue<>(); + + @Override + public void onWebSocketClose(int statusCode, String reason) + { + closeLatch.countDown(); + } + + @Override + public void onWebSocketConnect(Session session) + { + this.session = session; + } + + @Override + public void onWebSocketError(Throwable cause) + { + cause.printStackTrace(System.err); + } + + @Override + public void onWebSocketPartialBinary(ByteBuffer payload, boolean fin) + { + // our testcases always send bytes limited in the US-ASCII range. + partialEvents.offer(String.format("BINARY[payload=<<<%s>>>, fin=%b]", BufferUtil.toUTF8String(payload), fin)); + } + + @Override + public void onWebSocketPartialText(String payload, boolean fin) + { + partialEvents.offer(String.format("TEXT[payload=%s, fin=%b]", TextUtil.maxStringLength(30, payload), fin)); + } + } +} diff --git a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyListenerEventDriver.java b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyListenerEventDriver.java index 244460fadf8..125b68db913 100644 --- a/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyListenerEventDriver.java +++ b/jetty-websocket/websocket-common/src/main/java/org/eclipse/jetty/websocket/common/events/JettyListenerEventDriver.java @@ -23,6 +23,7 @@ import java.io.InputStream; import java.io.Reader; import java.nio.ByteBuffer; +import org.eclipse.jetty.util.Utf8StringBuilder; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.api.WebSocketConnectionListener; @@ -37,16 +38,21 @@ import org.eclipse.jetty.websocket.common.CloseInfo; import org.eclipse.jetty.websocket.common.frames.ReadOnlyDelegatedFrame; import org.eclipse.jetty.websocket.common.message.SimpleBinaryMessage; import org.eclipse.jetty.websocket.common.message.SimpleTextMessage; -import org.eclipse.jetty.websocket.common.util.Utf8PartialBuilder; /** * Handler for {@link WebSocketListener} based User WebSocket implementations. */ public class JettyListenerEventDriver extends AbstractEventDriver { + private enum PartialMode + { + NONE, TEXT, BINARY + } + private static final Logger LOG = Log.getLogger(JettyListenerEventDriver.class); private final WebSocketConnectionListener listener; - private Utf8PartialBuilder utf8Partial; + private Utf8StringBuilder utf8Partial; + private PartialMode partialMode = PartialMode.NONE; private boolean hasCloseBeenCalled = false; public JettyListenerEventDriver(WebSocketPolicy policy, WebSocketConnectionListener listener) @@ -70,7 +76,22 @@ public class JettyListenerEventDriver extends AbstractEventDriver if (listener instanceof WebSocketPartialListener) { - ((WebSocketPartialListener)listener).onWebSocketPartialBinary(buffer.slice().asReadOnlyBuffer(), fin); + switch (partialMode) + { + case NONE: + partialMode = PartialMode.BINARY; + // fallthru + case BINARY: + ((WebSocketPartialListener)listener).onWebSocketPartialBinary(buffer.slice().asReadOnlyBuffer(), fin); + break; + case TEXT: + throw new IOException("Out of order binary frame encountered"); + } + + if (fin) + { + partialMode = PartialMode.NONE; + } } } @@ -160,18 +181,39 @@ public class JettyListenerEventDriver extends AbstractEventDriver if (listener instanceof WebSocketPartialListener) { - if (utf8Partial == null) + switch (partialMode) { - utf8Partial = new Utf8PartialBuilder(); + case NONE: + partialMode = PartialMode.TEXT; + // fallthru + case TEXT: + if (utf8Partial == null) + { + utf8Partial = new Utf8StringBuilder(); + } + + String partial = ""; + + if (buffer != null) + { + utf8Partial.append(buffer); + partial = utf8Partial.takePartialString(); + } + + ((WebSocketPartialListener)listener).onWebSocketPartialText(partial, fin); + + if (fin) + { + utf8Partial = null; + } + break; + case BINARY: + throw new IOException("Out of order text frame encountered"); } - String partial = utf8Partial.toPartialString(buffer); - - ((WebSocketPartialListener)listener).onWebSocketPartialText(partial, fin); - if (fin) { - partial = null; + partialMode = PartialMode.NONE; } } } @@ -190,6 +232,27 @@ public class JettyListenerEventDriver extends AbstractEventDriver } } + public void onContinuationFrame(ByteBuffer buffer, boolean fin) throws IOException + { + if (listener instanceof WebSocketPartialListener) + { + switch (partialMode) + { + case NONE: + throw new IOException("Out of order Continuation frame encountered"); + case TEXT: + onTextFrame(buffer, fin); + break; + case BINARY: + onBinaryFrame(buffer, fin); + break; + } + return; + } + + super.onContinuationFrame(buffer, fin); + } + @Override public String toString() {