diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/CoreSession.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/CoreSession.java index c9f6bae28b3..4c662787244 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/CoreSession.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/CoreSession.java @@ -121,21 +121,35 @@ public interface CoreSession extends OutgoingFrames, IncomingFrames, Configurati void flush(Callback callback); /** - * Initiate close handshake, no payload (no declared status code or reason phrase) + * Initiate close handshake, no payload (no declared status code or reason phrase). * - * @param callback the callback to track close frame sent (or failed) + * @param callback the callback to track close frame sent (or failed). */ void close(Callback callback); /** - * Initiate close handshake with provide status code and optional reason phrase. + * Initiate close handshake with provided status code and optional reason phrase. * - * @param statusCode the status code (should be a valid status code that can be sent) - * @param reason optional reason phrase (will be truncated automatically by implementation to fit within limits of protocol) - * @param callback the callback to track close frame sent (or failed) + * @param statusCode the status code (should be a valid status code that can be sent). + * @param reason optional reason phrase (will be truncated automatically by implementation to fit within limits of protocol). + * @param callback the callback to track close frame sent (or failed). */ void close(int statusCode, String reason, Callback callback); + /** + * Initiate close handshake with a provided {@link CloseStatus}. + * + * @param closeStatus the close status containing (statusCode, reason, and optional {@link Throwable} cause). + * @param callback the callback to track close frame sent (or failed). + */ + default void close(CloseStatus closeStatus, Callback callback) + { + if (this instanceof WebSocketCoreSession coreSession) + coreSession.close(closeStatus, callback); + else + close(closeStatus.getCode(), closeStatus.getReason(), callback); + } + /** * Issue a harsh abort of the underlying connection. *

@@ -287,6 +301,12 @@ public interface CoreSession extends OutgoingFrames, IncomingFrames, Configurati callback.succeeded(); } + @Override + public void close(CloseStatus closeStatus, Callback callback) + { + callback.succeeded(); + } + @Override public void demand() { diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java index 835fe66f66c..e328e018bb0 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java @@ -194,31 +194,20 @@ public class WebSocketCoreSession implements CoreSession, Dumpable this.connection = connection; } - /** - * Send Close Frame with no payload. - * - * @param callback the callback on successful send of close frame - */ @Override public void close(Callback callback) { close(NO_CODE, callback); } - /** - * Send Close Frame with specified Status Code and optional Reason - * - * @param statusCode a valid WebSocket status code - * @param reason an optional reason phrase - * @param callback the callback on successful send of close frame - */ @Override public void close(int statusCode, String reason, Callback callback) { close(new CloseStatus(statusCode, reason), callback); } - private void close(CloseStatus closeStatus, Callback callback) + @Override + public void close(CloseStatus closeStatus, Callback callback) { sendFrame(closeStatus.toFrame(), callback, false); } diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/messages/DispatchedMessageSink.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/messages/DispatchedMessageSink.java index c1530296691..7deb4ed60d8 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/messages/DispatchedMessageSink.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/messages/DispatchedMessageSink.java @@ -18,10 +18,12 @@ import java.io.InputStream; import java.io.Reader; import java.lang.invoke.MethodHandle; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.IO; +import org.eclipse.jetty.websocket.core.CloseStatus; import org.eclipse.jetty.websocket.core.CoreSession; import org.eclipse.jetty.websocket.core.Frame; @@ -100,7 +102,17 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink // frame, while this MessageSink manages the demand when both // the last frame and the dispatched thread are completed. if (failure == null) + { autoDemand(); + } + else + { + if (failure instanceof CompletionException completionException) + failure = completionException.getCause(); + + CloseStatus closeStatus = new CloseStatus(CloseStatus.SERVER_ERROR, failure); + getCoreSession().close(closeStatus, Callback.NOOP); + } }); } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/MessageReaderErrorTest.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/MessageReaderErrorTest.java new file mode 100644 index 00000000000..1e8cf2ecad7 --- /dev/null +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/MessageReaderErrorTest.java @@ -0,0 +1,194 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.websocket.tests.listeners; + +import java.io.Reader; +import java.net.URI; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.util.BlockingArrayQueue; +import org.eclipse.jetty.util.IO; +import org.eclipse.jetty.util.Utf8StringBuilder; +import org.eclipse.jetty.websocket.api.Callback; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.StatusCode; +import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose; +import org.eclipse.jetty.websocket.api.annotations.OnWebSocketError; +import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage; +import org.eclipse.jetty.websocket.api.annotations.WebSocket; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.server.WebSocketUpgradeHandler; +import org.eclipse.jetty.websocket.tests.EventSocket; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MessageReaderErrorTest +{ + private Server _server; + private ServerConnector _connector; + private WebSocketClient _client; + private WebSocketUpgradeHandler _upgradeHandler; + + @BeforeEach + public void before() throws Exception + { + _server = new Server(); + _connector = new ServerConnector(_server); + _server.addConnector(_connector); + + _upgradeHandler = WebSocketUpgradeHandler.from(_server); + _server.setHandler(_upgradeHandler); + + _server.start(); + _client = new WebSocketClient(); + _client.start(); + } + + @AfterEach + public void after() throws Exception + { + _client.stop(); + _server.stop(); + } + + @WebSocket + public static class ReaderErrorEndpoint + { + public final int toRead; + public CountDownLatch closeLatch = new CountDownLatch(1); + public BlockingQueue textMessages = new BlockingArrayQueue<>(); + public volatile String closeReason; + public volatile int closeCode = StatusCode.UNDEFINED; + public volatile Throwable error = null; + + public ReaderErrorEndpoint() + { + this(-1); + } + + public ReaderErrorEndpoint(int read) + { + toRead = read; + } + + @OnWebSocketMessage + public void onMessage(Reader reader) throws Exception + { + if (toRead < 0) + { + textMessages.add(IO.toString(reader)); + } + else + { + Utf8StringBuilder sb = new Utf8StringBuilder(); + for (int i = 0; i < toRead; i++) + { + int read = reader.read(); + if (read < 0) + break; + sb.append((byte)read); + } + textMessages.add(sb.build()); + } + + // This reader will be dispatched to another thread and won't be the thread reading from the connection, + // however throwing from here should still fail the websocket connection. + throw new IllegalStateException("failed from test"); + } + + @OnWebSocketError + public void onError(Throwable t) + { + error = t; + } + + @OnWebSocketClose + public void onClose(int code, String reason) + { + closeCode = code; + closeReason = reason; + closeLatch.countDown(); + } + } + + @Test + public void testReaderOnError() throws Exception + { + ReaderErrorEndpoint serverEndpoint = new ReaderErrorEndpoint(); + _upgradeHandler.getServerWebSocketContainer() + .addMapping("/", (req, resp, cb) -> serverEndpoint); + + URI uri = URI.create("ws://localhost:" + _connector.getLocalPort()); + EventSocket clientEndpoint = new EventSocket(); + Session session = _client.connect(clientEndpoint, uri).get(5, TimeUnit.SECONDS); + session.sendPartialText("hel", false, Callback.NOOP); + session.sendPartialText("lo ", false, Callback.NOOP); + session.sendPartialText("wor", false, Callback.NOOP); + session.sendPartialText("ld", false, Callback.NOOP); + session.sendPartialText(null, true, Callback.NOOP); + + assertThat(serverEndpoint.textMessages.poll(5, TimeUnit.SECONDS), equalTo("hello world")); + + assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(serverEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR)); + assertThat(serverEndpoint.closeReason, containsString("failed from test")); + assertThat(serverEndpoint.error, instanceOf(IllegalStateException.class)); + + assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(clientEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR)); + assertThat(clientEndpoint.closeReason, containsString("failed from test")); + assertNull(clientEndpoint.error); + } + + @Test + public void testReaderOnErrorPartialRead() throws Exception + { + ReaderErrorEndpoint serverEndpoint = new ReaderErrorEndpoint(5); + _upgradeHandler.getServerWebSocketContainer() + .addMapping("/", (req, resp, cb) -> serverEndpoint); + + URI uri = URI.create("ws://localhost:" + _connector.getLocalPort()); + EventSocket clientEndpoint = new EventSocket(); + Session session = _client.connect(clientEndpoint, uri).get(5, TimeUnit.SECONDS); + session.sendPartialText("hel", false, Callback.NOOP); + session.sendPartialText("lo ", false, Callback.NOOP); + session.sendPartialText("wor", false, Callback.NOOP); + session.sendPartialText("ld", false, Callback.NOOP); + session.sendPartialText(null, true, Callback.NOOP); + + assertThat(serverEndpoint.textMessages.poll(5, TimeUnit.SECONDS), equalTo("hello")); + + assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(serverEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR)); + assertThat(serverEndpoint.closeReason, containsString("failed from test")); + assertThat(serverEndpoint.error, instanceOf(IllegalStateException.class)); + + assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(clientEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR)); + assertThat(clientEndpoint.closeReason, containsString("failed from test")); + assertNull(clientEndpoint.error); + } +}