diff --git a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java index 70c2a7527b6..4266ed51542 100644 --- a/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java +++ b/jetty-websocket/websocket-common/src/test/java/org/eclipse/jetty/websocket/common/message/MessageInputStreamTest.java @@ -24,12 +24,15 @@ import java.nio.charset.StandardCharsets; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.io.MappedByteBufferPool; import org.eclipse.jetty.toolchain.test.jupiter.WorkDir; import org.eclipse.jetty.toolchain.test.jupiter.WorkDirExtension; +import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.websocket.api.SuspendToken; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -37,6 +40,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import static java.time.Duration.ofSeconds; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; @ExtendWith(WorkDirExtension.class) public class MessageInputStreamTest @@ -48,116 +53,107 @@ public class MessageInputStreamTest @Test public void testBasicAppendRead() throws IOException { - try (MessageInputStream stream = new MessageInputStream(new EmptySession())) - { - Assertions.assertTimeoutPreemptively(ofSeconds(5), () -> - { - // Append a single message (simple, short) - ByteBuffer payload = BufferUtil.toBuffer("Hello World", StandardCharsets.UTF_8); - boolean fin = true; - stream.appendFrame(payload, fin); + StreamTestSession session = new StreamTestSession(); + MessageInputStream stream = new MessageInputStream(session); + session.setMessageInputStream(stream); - // Read entire message it from the stream. - byte[] buf = new byte[32]; - int len = stream.read(buf); - String message = new String(buf, 0, len, StandardCharsets.UTF_8); + // Append a single message (simple, short) + ByteBuffer payload = BufferUtil.toBuffer("Hello World!", StandardCharsets.UTF_8); + session.addContent(payload, true); + session.provideContent(); - // Test it - assertThat("Message", message, is("Hello World")); - }); - } + // Read entire message it from the stream. + byte[] buf = new byte[32]; + int len = stream.read(buf); + String message = new String(buf, 0, len, StandardCharsets.UTF_8); + + // Test it + assertThat("Message", message, is("Hello World!")); } @Test public void testBlockOnRead() throws Exception { - try (MessageInputStream stream = new MessageInputStream(new EmptySession())) + StreamTestSession session = new StreamTestSession(); + MessageInputStream stream = new MessageInputStream(session); + session.setMessageInputStream(stream); + new Thread(session::provideContent).start(); + + final AtomicBoolean hadError = new AtomicBoolean(false); + final CountDownLatch startLatch = new CountDownLatch(1); + + // This thread fills the stream (from the "worker" thread) + // But slowly (intentionally). + new Thread(() -> { - final AtomicBoolean hadError = new AtomicBoolean(false); - final CountDownLatch startLatch = new CountDownLatch(1); - - // This thread fills the stream (from the "worker" thread) - // But slowly (intentionally). - new Thread(new Runnable() + try { - @Override - public void run() - { - try - { - startLatch.countDown(); - boolean fin = false; - TimeUnit.MILLISECONDS.sleep(200); - stream.appendFrame(BufferUtil.toBuffer("Saved", StandardCharsets.UTF_8), fin); - TimeUnit.MILLISECONDS.sleep(200); - stream.appendFrame(BufferUtil.toBuffer(" by ", StandardCharsets.UTF_8), fin); - fin = true; - TimeUnit.MILLISECONDS.sleep(200); - stream.appendFrame(BufferUtil.toBuffer("Zero", StandardCharsets.UTF_8), fin); - } - catch (IOException | InterruptedException e) - { - hadError.set(true); - e.printStackTrace(System.err); - } - } - }).start(); - - Assertions.assertTimeoutPreemptively(ofSeconds(5), () -> + startLatch.countDown(); + TimeUnit.MILLISECONDS.sleep(200); + session.addContent("Saved", false); + TimeUnit.MILLISECONDS.sleep(200); + session.addContent(" by ", false); + TimeUnit.MILLISECONDS.sleep(200); + session.addContent("Zero", false); + TimeUnit.MILLISECONDS.sleep(200); + session.addContent("", true); + } + catch (Throwable t) { - // wait for thread to start - startLatch.await(); + hadError.set(true); + t.printStackTrace(System.err); + } + }).start(); - // Read it from the stream. - byte[] buf = new byte[32]; - int len = stream.read(buf); - String message = new String(buf, 0, len, StandardCharsets.UTF_8); + Assertions.assertTimeoutPreemptively(ofSeconds(5), () -> + { + // wait for thread to start + startLatch.await(); - // Test it - assertThat("Error when appending", hadError.get(), is(false)); - assertThat("Message", message, is("Saved by Zero")); - }); - } + // Read it from the stream. + byte[] buf = new byte[32]; + int len = stream.read(buf); + String message = new String(buf, 0, len, StandardCharsets.UTF_8); + + // Test it + assertThat("Error when appending", hadError.get(), is(false)); + assertThat("Message", message, is("Saved by Zero")); + }); } @Test public void testBlockOnReadInitial() throws IOException { - try (MessageInputStream stream = new MessageInputStream(new EmptySession())) + StreamTestSession session = new StreamTestSession(); + MessageInputStream stream = new MessageInputStream(session); + session.setMessageInputStream(stream); + session.addContent("I will conquer", true); + + AtomicReference error = new AtomicReference<>(); + new Thread(() -> { - final AtomicBoolean hadError = new AtomicBoolean(false); - - new Thread(new Runnable() + try { - @Override - public void run() - { - try - { - boolean fin = true; - // wait for a little bit before populating buffers - TimeUnit.MILLISECONDS.sleep(400); - stream.appendFrame(BufferUtil.toBuffer("I will conquer", StandardCharsets.UTF_8), fin); - } - catch (IOException | InterruptedException e) - { - hadError.set(true); - e.printStackTrace(System.err); - } - } - }).start(); - - Assertions.assertTimeoutPreemptively(ofSeconds(10), () -> + // wait for a little bit before initiating write to stream + TimeUnit.MILLISECONDS.sleep(1000); + session.provideContent(); + } + catch (Throwable t) { - // Read byte from stream. - int b = stream.read(); - // Should be a byte, blocking till byte received. + error.set(t); + t.printStackTrace(System.err); + } + }).start(); - // Test it - assertThat("Error when appending", hadError.get(), is(false)); - assertThat("Initial byte", b, is((int)'I')); - }); - } + Assertions.assertTimeoutPreemptively(ofSeconds(10), () -> + { + // Read byte from stream, block until byte received. + int b = stream.read(); + assertThat("Initial byte", b, is((int)'I')); + + // No error occurred. + assertNull(error.get()); + }); } @Test @@ -167,89 +163,160 @@ public class MessageInputStreamTest { final AtomicBoolean hadError = new AtomicBoolean(false); - new Thread(new Runnable() + new Thread(() -> { - @Override - public void run() + try { - try - { - // wait for a little bit before sending input closed - TimeUnit.MILLISECONDS.sleep(400); - stream.messageComplete(); - } - catch (InterruptedException e) - { - hadError.set(true); - e.printStackTrace(System.err); - } + // wait for a little bit before sending input closed + TimeUnit.MILLISECONDS.sleep(1000); + stream.messageComplete(); + } + catch (InterruptedException e) + { + hadError.set(true); + e.printStackTrace(System.err); } }).start(); Assertions.assertTimeoutPreemptively(ofSeconds(10), () -> { - // Read byte from stream. + // Read byte from stream. Should be a -1, indicating the end of the stream. int b = stream.read(); - // Should be a -1, indicating the end of the stream. - - // Test it - assertThat("Error when appending", hadError.get(), is(false)); assertThat("Initial byte", b, is(-1)); + + // No error occurred. + assertThat("Error when appending", hadError.get(), is(false)); }); } } @Test - public void testAppendEmptyPayloadRead() throws IOException + public void testSplitMessageWithEmptyPayloads() throws IOException { - try (MessageInputStream stream = new MessageInputStream(new EmptySession())) - { - Assertions.assertTimeoutPreemptively(ofSeconds(10), () -> - { - // Append parts of message - ByteBuffer msg1 = BufferUtil.toBuffer("Hello ", StandardCharsets.UTF_8); - ByteBuffer msg2 = ByteBuffer.allocate(0); // what is being tested - ByteBuffer msg3 = BufferUtil.toBuffer("World", StandardCharsets.UTF_8); + StreamTestSession session = new StreamTestSession(); + MessageInputStream stream = new MessageInputStream(session); + session.setMessageInputStream(stream); - stream.appendFrame(msg1, false); - stream.appendFrame(msg2, false); - stream.appendFrame(msg3, true); + session.addContent("", false); + session.addContent("Hello", false); + session.addContent("", false); + session.addContent(" World", false); + session.addContent("!", false); + session.addContent("", true); + session.provideContent(); - // Read entire message it from the stream. - byte[] buf = new byte[32]; - int len = stream.read(buf); - String message = new String(buf, 0, len, StandardCharsets.UTF_8); + // Read entire message it from the stream. + byte[] buf = new byte[32]; + int len = stream.read(buf); + String message = new String(buf, 0, len, StandardCharsets.UTF_8); - // Test it - assertThat("Message", message, is("Hello World")); - }); - } + // Test it + assertThat("Message", message, is("Hello World!")); } @Test - public void testAppendNullPayloadRead() throws IOException + public void testReadBeforeFirstAppend() throws IOException { - try (MessageInputStream stream = new MessageInputStream(new EmptySession())) + StreamTestSession session = new StreamTestSession(); + MessageInputStream stream = new MessageInputStream(session); + session.setMessageInputStream(stream); + + // Append a single message (simple, short) + session.addContent(BufferUtil.EMPTY_BUFFER, false); + session.addContent("Hello World", true); + + new Thread(() -> { - Assertions.assertTimeoutPreemptively(ofSeconds(10), () -> + try { - // Append parts of message - ByteBuffer msg1 = BufferUtil.toBuffer("Hello ", StandardCharsets.UTF_8); - ByteBuffer msg2 = null; // what is being tested - ByteBuffer msg3 = BufferUtil.toBuffer("World", StandardCharsets.UTF_8); + Thread.sleep(2000); + session.provideContent(); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + }).start(); - stream.appendFrame(msg1, false); - stream.appendFrame(msg2, false); - stream.appendFrame(msg3, true); + // Read entire message it from the stream. + byte[] buf = new byte[32]; + int len = stream.read(buf); + String message = new String(buf, 0, len, StandardCharsets.UTF_8); - // Read entire message it from the stream. - byte[] buf = new byte[32]; - int len = stream.read(buf); - String message = new String(buf, 0, len, StandardCharsets.UTF_8); + // Test it + assertThat("Message", message, is("Hello World")); + } - // Test it - assertThat("Message", message, is("Hello World")); - }); + public static class StreamTestSession extends EmptySession + { + private static final ByteBuffer EOF = BufferUtil.allocate(0); + private final AtomicBoolean suspended = new AtomicBoolean(false); + private BlockingArrayQueue contentQueue = new BlockingArrayQueue<>(); + private MessageInputStream stream; + + public void setMessageInputStream(MessageInputStream stream) + { + this.stream = stream; + } + + public void addContent(String content, boolean last) + { + addContent(BufferUtil.toBuffer(content, StandardCharsets.UTF_8), last); + } + + public void addContent(ByteBuffer content, boolean last) + { + contentQueue.add(content); + if (last) + contentQueue.add(EOF); + } + + public void provideContent() + { + pollAndAppendFrame(); + } + + @Override + public void resume() + { + if (!suspended.compareAndSet(true, false)) + throw new IllegalStateException(); + pollAndAppendFrame(); + } + + @Override + public SuspendToken suspend() + { + if (!suspended.compareAndSet(false, true)) + throw new IllegalStateException(); + return super.suspend(); + } + + private void pollAndAppendFrame() + { + try + { + while (true) + { + ByteBuffer content = contentQueue.poll(10, TimeUnit.SECONDS); + assertNotNull(content); + + boolean eof = (content == EOF); + stream.appendFrame(content, eof); + if (eof) + { + stream.messageComplete(); + break; + } + + if (suspended.get()) + break; + } + } + catch (Exception e) + { + throw new RuntimeException(e); + } } } }