diff --git a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java index 259d5da484a..543ca5ecabd 100644 --- a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java +++ b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/server/TextStreamTest.java @@ -20,16 +20,29 @@ package org.eclipse.jetty.websocket.javax.tests.server; import java.io.IOException; import java.io.Reader; +import java.io.StringWriter; import java.io.Writer; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.websocket.ClientEndpoint; +import javax.websocket.ContainerProvider; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; import javax.websocket.OnMessage; import javax.websocket.Session; +import javax.websocket.WebSocketContainer; import javax.websocket.server.ServerContainer; import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; +import org.eclipse.jetty.util.BlockingArrayQueue; +import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.core.CloseStatus; @@ -38,29 +51,39 @@ import org.eclipse.jetty.websocket.core.OpCode; import org.eclipse.jetty.websocket.javax.tests.DataUtils; import org.eclipse.jetty.websocket.javax.tests.Fuzzer; import org.eclipse.jetty.websocket.javax.tests.LocalServer; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; + public class TextStreamTest { private static final Logger LOG = Log.getLogger(TextStreamTest.class); + private static final BlockingArrayQueue serverEndpoints = new BlockingArrayQueue<>(); - private static LocalServer server; - private static ServerContainer container; + private LocalServer server; + private ServerContainer container; + private WebSocketContainer wsClient; - @BeforeAll - public static void startServer() throws Exception + @BeforeEach + public void startServer() throws Exception { server = new LocalServer(); server.start(); container = server.getServerContainer(); container.addEndpoint(ServerTextStreamer.class); + container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedTextStreamer.class, "/test").build()); + container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedPartialTextStreamer.class, "/partial").build()); + + wsClient = ContainerProvider.getWebSocketContainer(); } - @AfterAll - public static void stopServer() throws Exception + @AfterEach + public void stopServer() throws Exception { server.stop(); } @@ -145,6 +168,105 @@ public class TextStreamTest } } + @Test + public void testMessageOrdering() throws Exception + { + ClientTextStreamer client = new ClientTextStreamer(); + Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test")); + + final int numLoops = 20; + for (int i = 0; i < numLoops; i++) + { + session.getBasicRemote().sendText(Integer.toString(i)); + } + session.close(); + + QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS); + assertNotNull(queuedTextStreamer); + for (int i = 0; i < numLoops; i++) + { + String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS); + assertThat(msg, Matchers.is(Integer.toString(i))); + } + } + + @Test + public void testFragmentedMessageOrdering() throws Exception + { + ClientTextStreamer client = new ClientTextStreamer(); + Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/test")); + + final int numLoops = 20; + for (int i = 0; i < numLoops; i++) + { + session.getBasicRemote().sendText("firstFrame" + i, false); + session.getBasicRemote().sendText("|secondFrame" + i, false); + session.getBasicRemote().sendText("|finalFrame" + i, true); + } + session.close(); + + QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS); + assertNotNull(queuedTextStreamer); + for (int i = 0; i < numLoops; i++) + { + String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS); + String expected = "firstFrame" + i + "|secondFrame" + i + "|finalFrame" + i; + assertThat(msg, Matchers.is(expected)); + } + } + + @Test + public void testMessageOrderingDoNotReadToEOF() throws Exception + { + ClientTextStreamer client = new ClientTextStreamer(); + Session session = wsClient.connectToServer(client, server.getWsUri().resolve("/partial")); + + final int numLoops = 20; + for (int i = 0; i < numLoops; i++) + { + session.getBasicRemote().sendText(i + "|-----"); + } + session.close(); + + QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS); + assertNotNull(queuedTextStreamer); + for (int i = 0; i < numLoops; i++) + { + String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS); + assertThat(msg, Matchers.is(Integer.toString(i))); + } + } + + @ClientEndpoint + public static class ClientTextStreamer + { + private final CountDownLatch latch = new CountDownLatch(1); + private final StringBuilder output = new StringBuilder(); + + @OnMessage + public void echoed(Reader input) throws IOException + { + while (true) + { + int read = input.read(); + if (read < 0) + break; + output.append((char)read); + } + latch.countDown(); + } + + public char[] getEcho() + { + return output.toString().toCharArray(); + } + + public boolean await(long timeout, TimeUnit unit) throws InterruptedException + { + return latch.await(timeout, unit); + } + } + @ServerEndpoint("/echo") public static class ServerTextStreamer { @@ -166,4 +288,62 @@ public class TextStreamTest } } } + + public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole + { + protected BlockingArrayQueue messages = new BlockingArrayQueue<>(); + + public QueuedTextStreamer() + { + serverEndpoints.add(this); + } + + @Override + public void onOpen(Session session, EndpointConfig config) + { + session.addMessageHandler(this); + } + + @Override + public void onMessage(Reader input) + { + try + { + Thread.sleep(Math.abs(new Random().nextLong() % 200)); + messages.add(IO.toString(input)); + } + catch (Exception e) + { + e.printStackTrace(); + } + } + } + + public static class QueuedPartialTextStreamer extends QueuedTextStreamer + { + @Override + public void onMessage(Reader input) + { + try + { + Thread.sleep(Math.abs(new Random().nextLong() % 200)); + + // Do not read to EOF but just the first '|'. + StringWriter writer = new StringWriter(); + while (true) + { + int read = input.read(); + if (read < 0 || read == '|') + break; + writer.write(read); + } + + messages.add(writer.toString()); + } + catch (Exception e) + { + e.printStackTrace(); + } + } + } } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/CallbackBuffer.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/CallbackBuffer.java deleted file mode 100644 index afe76cfec22..00000000000 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/CallbackBuffer.java +++ /dev/null @@ -1,44 +0,0 @@ -// -// ======================================================================== -// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. -// -// This program and the accompanying materials are made available under -// the terms of the Eclipse Public License 2.0 which is available at -// https://www.eclipse.org/legal/epl-2.0 -// -// This Source Code may also be made available under the following -// Secondary Licenses when the conditions for such availability set -// forth in the Eclipse Public License, v. 2.0 are satisfied: -// the Apache License v2.0 which is available at -// https://www.apache.org/licenses/LICENSE-2.0 -// -// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 -// ======================================================================== -// - -package org.eclipse.jetty.websocket.util.messages; - -import java.nio.ByteBuffer; -import java.util.Objects; - -import org.eclipse.jetty.util.BufferUtil; -import org.eclipse.jetty.util.Callback; - -public class CallbackBuffer -{ - public ByteBuffer buffer; - public Callback callback; - - public CallbackBuffer(Callback callback, ByteBuffer buffer) - { - Objects.requireNonNull(buffer, "buffer"); - this.callback = callback; - this.buffer = buffer; - } - - @Override - public String toString() - { - return String.format("CallbackBuffer[%s,%s]", BufferUtil.toDetailString(buffer), callback.getClass().getSimpleName()); - } -} diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java index f4eab98533d..107d8046b40 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/DispatchedMessageSink.java @@ -22,6 +22,7 @@ import java.lang.invoke.MethodHandle; import java.util.concurrent.CompletableFuture; import org.eclipse.jetty.util.Callback; +import org.eclipse.jetty.util.FutureCallback; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; @@ -93,11 +94,8 @@ import org.eclipse.jetty.websocket.core.Frame; * EOF stream.read EOF * RESUME(NEXT MSG) * - * - * @param the type of object to give to user function */ -@SuppressWarnings("Duplicates") -public abstract class DispatchedMessageSink extends AbstractMessageSink +public abstract class DispatchedMessageSink extends AbstractMessageSink { private CompletableFuture dispatchComplete; private MessageSink typeSink; @@ -114,14 +112,14 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink if (typeSink == null) { typeSink = newSink(frame); - // Dispatch to end user function (will likely start with blocking for data/accept) dispatchComplete = new CompletableFuture<>(); + + // Dispatch to end user function (will likely start with blocking for data/accept) new Thread(() -> { - final T dispatchedType = (T)typeSink; try { - methodHandle.invoke(dispatchedType); + methodHandle.invoke(typeSink); dispatchComplete.complete(null); } catch (Throwable throwable) @@ -131,40 +129,21 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink }).start(); } - final Callback frameCallback; - + Callback frameCallback = callback; if (frame.isFin()) { - CompletableFuture finComplete = new CompletableFuture<>(); - frameCallback = new Callback() + // This is the final frame we should wait for the frame callback and the dispatched thread. + Callback.Completable completableCallback = new Callback.Completable(); + frameCallback = completableCallback; + CompletableFuture.allOf(dispatchComplete, completableCallback).whenComplete((aVoid, throwable) -> { - @Override - public void failed(Throwable cause) - { - finComplete.completeExceptionally(cause); - } - - @Override - public void succeeded() - { - finComplete.complete(null); - } - }; - CompletableFuture.allOf(dispatchComplete, finComplete).whenComplete( - (aVoid, throwable) -> - { - typeSink = null; - dispatchComplete = null; - if (throwable != null) - callback.failed(throwable); - else - callback.succeeded(); - }); - } - else - { - // Non-fin-frame - frameCallback = callback; + typeSink = null; + dispatchComplete = null; + if (throwable != null) + callback.failed(throwable); + else + callback.succeeded(); + }); } typeSink.accept(frame, frameCallback); diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/InputStreamMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/InputStreamMessageSink.java index b8413b1c1c9..161b367dd2e 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/InputStreamMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/InputStreamMessageSink.java @@ -18,13 +18,12 @@ package org.eclipse.jetty.websocket.util.messages; -import java.io.InputStream; import java.lang.invoke.MethodHandle; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; -public class InputStreamMessageSink extends DispatchedMessageSink +public class InputStreamMessageSink extends DispatchedMessageSink { public InputStreamMessageSink(CoreSession session, MethodHandle methodHandle) { diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java index 6e454bfb4d9..220f3afabc1 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageInputStream.java @@ -21,10 +21,12 @@ package org.eclipse.jetty.websocket.util.messages; import java.io.IOException; import java.io.InputStream; import java.io.InterruptedIOException; -import java.util.ArrayDeque; -import java.util.Deque; +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.log.Log; @@ -40,10 +42,11 @@ import org.eclipse.jetty.websocket.core.Frame; public class MessageInputStream extends InputStream implements MessageSink { private static final Logger LOG = Log.getLogger(MessageInputStream.class); - private static final CallbackBuffer EOF = new CallbackBuffer(Callback.NOOP, BufferUtil.EMPTY_BUFFER); - private final Deque buffers = new ArrayDeque<>(2); + private static final Entry EOF = new Entry(BufferUtil.EMPTY_BUFFER, Callback.NOOP); + private final BlockingArrayQueue buffers = new BlockingArrayQueue<>(); private final AtomicBoolean closed = new AtomicBoolean(false); - private CallbackBuffer activeFrame; + private Entry currentEntry; + private long timeoutMs = -1; @Override public void accept(Frame frame, Callback callback) @@ -51,119 +54,20 @@ public class MessageInputStream extends InputStream implements MessageSink if (LOG.isDebugEnabled()) LOG.debug("accepting {}", frame); - // If closed, we should just toss incoming payloads into the bit bucket. - if (closed.get()) - { - callback.failed(new IOException("Already Closed")); - return; - } - - if (!frame.hasPayload() && !frame.isFin()) + // If closed or we have no payload, request the next frame. + if (closed.get() || (!frame.hasPayload() && !frame.isFin())) { callback.succeeded(); return; } - synchronized (buffers) - { - boolean notify = false; - if (frame.hasPayload()) - { - buffers.offer(new CallbackBuffer(callback, frame.getPayload())); - notify = true; - } - else - { - // We cannot wake up blocking read for a zero length frame. - callback.succeeded(); - } + if (frame.hasPayload()) + buffers.add(new Entry(frame.getPayload(), callback)); + else + callback.succeeded(); - if (frame.isFin()) - { - buffers.offer(EOF); - notify = true; - } - - if (notify) - { - // notify other thread - buffers.notify(); - } - } - } - - @Override - public void close() throws IOException - { - if (LOG.isDebugEnabled()) - LOG.debug("close()"); - - if (closed.compareAndSet(false, true)) - { - synchronized (buffers) - { - buffers.offer(EOF); - buffers.notify(); - } - } - super.close(); - } - - public CallbackBuffer getActiveFrame() throws InterruptedIOException - { - if (activeFrame == null) - { - // sync and poll queue - CallbackBuffer result; - synchronized (buffers) - { - try - { - while ((result = buffers.poll()) == null) - { - // TODO: handle read timeout here? - buffers.wait(); - } - } - catch (InterruptedException e) - { - shutdown(); - throw new InterruptedIOException(); - } - } - activeFrame = result; - } - - return activeFrame; - } - - private void shutdown() - { - if (LOG.isDebugEnabled()) - LOG.debug("shutdown()"); - synchronized (buffers) - { - closed.set(true); - Throwable cause = new IOException("Shutdown"); - for (CallbackBuffer buffer : buffers) - { - buffer.callback.failed(cause); - } - // Removed buffers that may have remained in the queue. - buffers.clear(); - } - } - - @Override - public void mark(int readlimit) - { - // Not supported. - } - - @Override - public boolean markSupported() - { - return false; + if (frame.isFin()) + buffers.add(EOF); } @Override @@ -185,14 +89,9 @@ public class MessageInputStream extends InputStream implements MessageSink public int read(final byte[] b, final int off, final int len) throws IOException { if (closed.get()) - { - if (LOG.isDebugEnabled()) - LOG.debug("Stream closed"); return -1; - } - - CallbackBuffer result = getActiveFrame(); + Entry result = getCurrentEntry(); if (LOG.isDebugEnabled()) LOG.debug("result = {}", result); @@ -207,10 +106,9 @@ public class MessageInputStream extends InputStream implements MessageSink // We have content int fillLen = Math.min(result.buffer.remaining(), len); result.buffer.get(b, off, fillLen); - if (!result.buffer.hasRemaining()) { - activeFrame = null; + currentEntry = null; result.callback.succeeded(); } @@ -219,8 +117,94 @@ public class MessageInputStream extends InputStream implements MessageSink } @Override - public void reset() throws IOException + public void close() throws IOException { - throw new IOException("reset() not supported"); + if (LOG.isDebugEnabled()) + LOG.debug("close()"); + + if (closed.compareAndSet(false, true)) + { + synchronized (buffers) + { + buffers.offer(EOF); + buffers.notify(); + } + } + + super.close(); + } + + private void shutdown() + { + if (LOG.isDebugEnabled()) + LOG.debug("shutdown()"); + + synchronized (this) + { + closed.set(true); + Throwable cause = new IOException("Shutdown"); + for (Entry buffer : buffers) + { + buffer.callback.failed(cause); + } + + // Removed buffers that may have remained in the queue. + buffers.clear(); + } + } + + public void setTimeout(long timeoutMs) + { + this.timeoutMs = timeoutMs; + } + + private Entry getCurrentEntry() throws IOException + { + if (currentEntry != null) + return currentEntry; + + // sync and poll queue + try + { + if (LOG.isDebugEnabled()) + LOG.debug("Waiting {} ms to read", timeoutMs); + if (timeoutMs < 0) + { + // Wait forever until a buffer is available. + currentEntry = buffers.take(); + } + else + { + // Wait at most for the given timeout. + currentEntry = buffers.poll(timeoutMs, TimeUnit.MILLISECONDS); + if (currentEntry == null) + throw new IOException(String.format("Read timeout: %,dms expired", timeoutMs)); + } + } + catch (InterruptedException e) + { + shutdown(); + throw new InterruptedIOException(); + } + + return currentEntry; + } + + private static class Entry + { + public ByteBuffer buffer; + public Callback callback; + + public Entry(ByteBuffer buffer, Callback callback) + { + this.buffer = Objects.requireNonNull(buffer); + this.callback = callback; + } + + @Override + public String toString() + { + return String.format("Entry[%s,%s]", BufferUtil.toDetailString(buffer), callback.getClass().getSimpleName()); + } } } diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageReader.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageReader.java index b97eb28b1ca..1eac8c3d680 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageReader.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/MessageReader.java @@ -33,10 +33,15 @@ public class MessageReader extends InputStreamReader implements MessageSink { private final MessageInputStream stream; - public MessageReader(MessageInputStream stream) + public MessageReader() { - super(stream, StandardCharsets.UTF_8); - this.stream = stream; + this(new MessageInputStream()); + } + + private MessageReader(MessageInputStream inputStream) + { + super(inputStream, StandardCharsets.UTF_8); + this.stream = inputStream; } @Override diff --git a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ReaderMessageSink.java b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ReaderMessageSink.java index 7cfeb14f7b6..dc17238ca81 100644 --- a/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ReaderMessageSink.java +++ b/jetty-websocket/websocket-util/src/main/java/org/eclipse/jetty/websocket/util/messages/ReaderMessageSink.java @@ -18,13 +18,12 @@ package org.eclipse.jetty.websocket.util.messages; -import java.io.Reader; import java.lang.invoke.MethodHandle; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; -public class ReaderMessageSink extends DispatchedMessageSink +public class ReaderMessageSink extends DispatchedMessageSink { public ReaderMessageSink(CoreSession session, MethodHandle methodHandle) { @@ -34,6 +33,6 @@ public class ReaderMessageSink extends DispatchedMessageSink @Override public MessageReader newSink(Frame frame) { - return new MessageReader(new MessageInputStream()); + return new MessageReader(); } }