Issue #11275 - explicitly close websocket endpoint after error from DispatchedMessageSink (#11343)

* Now properly handling errors
* Added test for partial read

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan 2024-01-30 01:13:25 +11:00 committed by GitHub
parent a33dd59b21
commit 554d5b19d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 234 additions and 19 deletions

View File

@ -121,21 +121,35 @@ public interface CoreSession extends OutgoingFrames, IncomingFrames, Configurati
void flush(Callback callback);
/**
* Initiate close handshake, no payload (no declared status code or reason phrase)
* Initiate close handshake, no payload (no declared status code or reason phrase).
*
* @param callback the callback to track close frame sent (or failed)
* @param callback the callback to track close frame sent (or failed).
*/
void close(Callback callback);
/**
* Initiate close handshake with provide status code and optional reason phrase.
* Initiate close handshake with provided status code and optional reason phrase.
*
* @param statusCode the status code (should be a valid status code that can be sent)
* @param reason optional reason phrase (will be truncated automatically by implementation to fit within limits of protocol)
* @param callback the callback to track close frame sent (or failed)
* @param statusCode the status code (should be a valid status code that can be sent).
* @param reason optional reason phrase (will be truncated automatically by implementation to fit within limits of protocol).
* @param callback the callback to track close frame sent (or failed).
*/
void close(int statusCode, String reason, Callback callback);
/**
* Initiate close handshake with a provided {@link CloseStatus}.
*
* @param closeStatus the close status containing (statusCode, reason, and optional {@link Throwable} cause).
* @param callback the callback to track close frame sent (or failed).
*/
default void close(CloseStatus closeStatus, Callback callback)
{
if (this instanceof WebSocketCoreSession coreSession)
coreSession.close(closeStatus, callback);
else
close(closeStatus.getCode(), closeStatus.getReason(), callback);
}
/**
* Issue a harsh abort of the underlying connection.
* <p>
@ -287,6 +301,12 @@ public interface CoreSession extends OutgoingFrames, IncomingFrames, Configurati
callback.succeeded();
}
@Override
public void close(CloseStatus closeStatus, Callback callback)
{
callback.succeeded();
}
@Override
public void demand()
{

View File

@ -194,31 +194,20 @@ public class WebSocketCoreSession implements CoreSession, Dumpable
this.connection = connection;
}
/**
* Send Close Frame with no payload.
*
* @param callback the callback on successful send of close frame
*/
@Override
public void close(Callback callback)
{
close(NO_CODE, callback);
}
/**
* Send Close Frame with specified Status Code and optional Reason
*
* @param statusCode a valid WebSocket status code
* @param reason an optional reason phrase
* @param callback the callback on successful send of close frame
*/
@Override
public void close(int statusCode, String reason, Callback callback)
{
close(new CloseStatus(statusCode, reason), callback);
}
private void close(CloseStatus closeStatus, Callback callback)
@Override
public void close(CloseStatus closeStatus, Callback callback)
{
sendFrame(closeStatus.toFrame(), callback, false);
}

View File

@ -18,10 +18,12 @@ import java.io.InputStream;
import java.io.Reader;
import java.lang.invoke.MethodHandle;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
@ -100,7 +102,17 @@ public abstract class DispatchedMessageSink extends AbstractMessageSink
// frame, while this MessageSink manages the demand when both
// the last frame and the dispatched thread are completed.
if (failure == null)
{
autoDemand();
}
else
{
if (failure instanceof CompletionException completionException)
failure = completionException.getCause();
CloseStatus closeStatus = new CloseStatus(CloseStatus.SERVER_ERROR, failure);
getCoreSession().close(closeStatus, Callback.NOOP);
}
});
}

View File

@ -0,0 +1,194 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.tests.listeners;
import java.io.Reader;
import java.net.URI;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketError;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.WebSocketUpgradeHandler;
import org.eclipse.jetty.websocket.tests.EventSocket;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class MessageReaderErrorTest
{
private Server _server;
private ServerConnector _connector;
private WebSocketClient _client;
private WebSocketUpgradeHandler _upgradeHandler;
@BeforeEach
public void before() throws Exception
{
_server = new Server();
_connector = new ServerConnector(_server);
_server.addConnector(_connector);
_upgradeHandler = WebSocketUpgradeHandler.from(_server);
_server.setHandler(_upgradeHandler);
_server.start();
_client = new WebSocketClient();
_client.start();
}
@AfterEach
public void after() throws Exception
{
_client.stop();
_server.stop();
}
@WebSocket
public static class ReaderErrorEndpoint
{
public final int toRead;
public CountDownLatch closeLatch = new CountDownLatch(1);
public BlockingQueue<String> textMessages = new BlockingArrayQueue<>();
public volatile String closeReason;
public volatile int closeCode = StatusCode.UNDEFINED;
public volatile Throwable error = null;
public ReaderErrorEndpoint()
{
this(-1);
}
public ReaderErrorEndpoint(int read)
{
toRead = read;
}
@OnWebSocketMessage
public void onMessage(Reader reader) throws Exception
{
if (toRead < 0)
{
textMessages.add(IO.toString(reader));
}
else
{
Utf8StringBuilder sb = new Utf8StringBuilder();
for (int i = 0; i < toRead; i++)
{
int read = reader.read();
if (read < 0)
break;
sb.append((byte)read);
}
textMessages.add(sb.build());
}
// This reader will be dispatched to another thread and won't be the thread reading from the connection,
// however throwing from here should still fail the websocket connection.
throw new IllegalStateException("failed from test");
}
@OnWebSocketError
public void onError(Throwable t)
{
error = t;
}
@OnWebSocketClose
public void onClose(int code, String reason)
{
closeCode = code;
closeReason = reason;
closeLatch.countDown();
}
}
@Test
public void testReaderOnError() throws Exception
{
ReaderErrorEndpoint serverEndpoint = new ReaderErrorEndpoint();
_upgradeHandler.getServerWebSocketContainer()
.addMapping("/", (req, resp, cb) -> serverEndpoint);
URI uri = URI.create("ws://localhost:" + _connector.getLocalPort());
EventSocket clientEndpoint = new EventSocket();
Session session = _client.connect(clientEndpoint, uri).get(5, TimeUnit.SECONDS);
session.sendPartialText("hel", false, Callback.NOOP);
session.sendPartialText("lo ", false, Callback.NOOP);
session.sendPartialText("wor", false, Callback.NOOP);
session.sendPartialText("ld", false, Callback.NOOP);
session.sendPartialText(null, true, Callback.NOOP);
assertThat(serverEndpoint.textMessages.poll(5, TimeUnit.SECONDS), equalTo("hello world"));
assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(serverEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR));
assertThat(serverEndpoint.closeReason, containsString("failed from test"));
assertThat(serverEndpoint.error, instanceOf(IllegalStateException.class));
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR));
assertThat(clientEndpoint.closeReason, containsString("failed from test"));
assertNull(clientEndpoint.error);
}
@Test
public void testReaderOnErrorPartialRead() throws Exception
{
ReaderErrorEndpoint serverEndpoint = new ReaderErrorEndpoint(5);
_upgradeHandler.getServerWebSocketContainer()
.addMapping("/", (req, resp, cb) -> serverEndpoint);
URI uri = URI.create("ws://localhost:" + _connector.getLocalPort());
EventSocket clientEndpoint = new EventSocket();
Session session = _client.connect(clientEndpoint, uri).get(5, TimeUnit.SECONDS);
session.sendPartialText("hel", false, Callback.NOOP);
session.sendPartialText("lo ", false, Callback.NOOP);
session.sendPartialText("wor", false, Callback.NOOP);
session.sendPartialText("ld", false, Callback.NOOP);
session.sendPartialText(null, true, Callback.NOOP);
assertThat(serverEndpoint.textMessages.poll(5, TimeUnit.SECONDS), equalTo("hello"));
assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(serverEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR));
assertThat(serverEndpoint.closeReason, containsString("failed from test"));
assertThat(serverEndpoint.error, instanceOf(IllegalStateException.class));
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeCode, equalTo(StatusCode.SERVER_ERROR));
assertThat(clientEndpoint.closeReason, containsString("failed from test"));
assertNull(clientEndpoint.error);
}
}