diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java index df1f1ef4149..a8ce47b3679 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java @@ -90,7 +90,7 @@ public class WebSocketSessionState public boolean isInputOpen() { State state = getState(); - return (state == State.OPEN || state == State.OSHUT); + return (state == State.CONNECTED || state == State.OPEN || state == State.OSHUT); } public boolean isOutputOpen() diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java index 49478cb482e..54de2c1b5a4 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/ExplicitDemandTest.java @@ -18,17 +18,23 @@ import java.net.URI; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.util.BlockingArrayQueue; +import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Frame; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.core.CloseStatus; +import org.eclipse.jetty.websocket.core.OpCode; import org.eclipse.jetty.websocket.server.WebSocketUpgradeHandler; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -36,6 +42,7 @@ import org.junit.jupiter.api.Test; import static org.awaitility.Awaitility.await; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -74,6 +81,38 @@ public class ExplicitDemandTest } } + @WebSocket(autoDemand = false) + public static class OnOpenSocket implements Session.Listener + { + CountDownLatch onOpen = new CountDownLatch(1); + BlockingQueue textMessages = new BlockingArrayQueue<>(); + Session session; + + @Override + public void onWebSocketOpen(Session session) + { + try + { + this.session = session; + session.demand(); + onOpen.await(); + } + catch (InterruptedException e) + { + throw new RuntimeException(e); + } + } + + @Override + public void onWebSocketFrame(Frame frame, Callback callback) + { + if (frame.getOpCode() == OpCode.TEXT) + textMessages.add(BufferUtil.toString(frame.getPayload())); + callback.succeed(); + + } + } + @WebSocket(autoDemand = false) public static class PingSocket extends ListenerSocket { @@ -99,6 +138,7 @@ public class ExplicitDemandTest private final WebSocketClient client = new WebSocketClient(); private final SuspendSocket serverSocket = new SuspendSocket(); private final ListenerSocket listenerSocket = new ListenerSocket(); + private final OnOpenSocket onOpenSocket = new OnOpenSocket(); private final PingSocket pingSocket = new PingSocket(); private ServerConnector connector; @@ -113,6 +153,7 @@ public class ExplicitDemandTest container.addMapping("/suspend", (rq, rs, cb) -> serverSocket); container.addMapping("/listenerSocket", (rq, rs, cb) -> listenerSocket); container.addMapping("/ping", (rq, rs, cb) -> pingSocket); + container.addMapping("/onOpen", (rq, rs, cb) -> onOpenSocket); }); server.setHandler(wsHandler); @@ -213,4 +254,23 @@ public class ExplicitDemandTest frame = pingSocket.frames.get(2); assertThat(frame.getType(), is(Frame.Type.CLOSE)); } + + @Test + public void testDemandInOnOpen() throws Exception + { + URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/onOpen"); + EventSocket clientSocket = new EventSocket(); + + Future connect = client.connect(clientSocket, uri); + Session session = connect.get(5, TimeUnit.SECONDS); + session.sendText("test-text", Callback.NOOP); + + String received = onOpenSocket.textMessages.poll(5, TimeUnit.SECONDS); + assertThat(received, equalTo("test-text")); + onOpenSocket.onOpen.countDown(); + + session.close(); + assertTrue(clientSocket.closeLatch.await(5, TimeUnit.SECONDS)); + assertThat(clientSocket.closeCode, equalTo(CloseStatus.NORMAL)); + } }