diff --git a/jetty-websocket/websocket-javax-client/src/main/java/org/eclipse/jetty/websocket/javax/client/internal/JavaxWebSocketClientFrameHandlerFactory.java b/jetty-websocket/websocket-javax-client/src/main/java/org/eclipse/jetty/websocket/javax/client/internal/JavaxWebSocketClientFrameHandlerFactory.java index 06b2b154432..ec33f0e1266 100644 --- a/jetty-websocket/websocket-javax-client/src/main/java/org/eclipse/jetty/websocket/javax/client/internal/JavaxWebSocketClientFrameHandlerFactory.java +++ b/jetty-websocket/websocket-javax-client/src/main/java/org/eclipse/jetty/websocket/javax/client/internal/JavaxWebSocketClientFrameHandlerFactory.java @@ -19,7 +19,6 @@ package org.eclipse.jetty.websocket.javax.client.internal; import javax.websocket.ClientEndpoint; -import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketContainer; @@ -49,7 +48,7 @@ public class JavaxWebSocketClientFrameHandlerFactory extends JavaxWebSocketFrame public JavaxWebSocketFrameHandlerMetadata getMetadata(Class endpointClass, EndpointConfig endpointConfig) { if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass)) - return createEndpointMetadata((Class)endpointClass, endpointConfig); + return createEndpointMetadata(endpointConfig); if (endpointClass.getAnnotation(ClientEndpoint.class) == null) return null; diff --git a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java index 6087274b386..853dc88d639 100644 --- a/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java +++ b/jetty-websocket/websocket-javax-common/src/main/java/org/eclipse/jetty/websocket/javax/common/JavaxWebSocketFrameHandlerFactory.java @@ -250,20 +250,20 @@ public abstract class JavaxWebSocketFrameHandlerFactory } } - protected JavaxWebSocketFrameHandlerMetadata createEndpointMetadata(Class endpointClass, EndpointConfig endpointConfig) + protected JavaxWebSocketFrameHandlerMetadata createEndpointMetadata(EndpointConfig endpointConfig) { JavaxWebSocketFrameHandlerMetadata metadata = new JavaxWebSocketFrameHandlerMetadata(endpointConfig); - MethodHandles.Lookup lookup = getApplicationMethodHandleLookup(endpointClass); + MethodHandles.Lookup lookup = getServerMethodHandleLookup(); - Method openMethod = ReflectUtils.findMethod(endpointClass, "onOpen", Session.class, EndpointConfig.class); + Method openMethod = ReflectUtils.findMethod(Endpoint.class, "onOpen", Session.class, EndpointConfig.class); MethodHandle open = toMethodHandle(lookup, openMethod); metadata.setOpenHandler(open, openMethod); - Method closeMethod = ReflectUtils.findMethod(endpointClass, "onClose", Session.class, CloseReason.class); + Method closeMethod = ReflectUtils.findMethod(Endpoint.class, "onClose", Session.class, CloseReason.class); MethodHandle close = toMethodHandle(lookup, closeMethod); metadata.setCloseHandler(close, closeMethod); - Method errorMethod = ReflectUtils.findMethod(endpointClass, "onError", Session.class, Throwable.class); + Method errorMethod = ReflectUtils.findMethod(Endpoint.class, "onError", Session.class, Throwable.class); MethodHandle error = toMethodHandle(lookup, errorMethod); metadata.setErrorHandler(error, errorMethod); diff --git a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/DummyFrameHandlerFactory.java b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/DummyFrameHandlerFactory.java index 12d00d0c67d..5ac2117464a 100644 --- a/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/DummyFrameHandlerFactory.java +++ b/jetty-websocket/websocket-javax-common/src/test/java/org/eclipse/jetty/websocket/javax/common/DummyFrameHandlerFactory.java @@ -20,7 +20,6 @@ package org.eclipse.jetty.websocket.javax.common; import javax.websocket.ClientEndpoint; import javax.websocket.ClientEndpointConfig; -import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import org.eclipse.jetty.websocket.util.InvokerUtils; @@ -43,7 +42,7 @@ public class DummyFrameHandlerFactory extends JavaxWebSocketFrameHandlerFactory { if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass)) { - return createEndpointMetadata((Class)endpointClass, endpointConfig); + return createEndpointMetadata(endpointConfig); } if (endpointClass.getAnnotation(ClientEndpoint.class) == null) diff --git a/jetty-websocket/websocket-javax-server/src/main/java/org/eclipse/jetty/websocket/javax/server/internal/JavaxWebSocketServerFrameHandlerFactory.java b/jetty-websocket/websocket-javax-server/src/main/java/org/eclipse/jetty/websocket/javax/server/internal/JavaxWebSocketServerFrameHandlerFactory.java index 4f6c29a38f4..480d5acad5a 100644 --- a/jetty-websocket/websocket-javax-server/src/main/java/org/eclipse/jetty/websocket/javax/server/internal/JavaxWebSocketServerFrameHandlerFactory.java +++ b/jetty-websocket/websocket-javax-server/src/main/java/org/eclipse/jetty/websocket/javax/server/internal/JavaxWebSocketServerFrameHandlerFactory.java @@ -18,7 +18,6 @@ package org.eclipse.jetty.websocket.javax.server.internal; -import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.server.ServerEndpoint; @@ -42,7 +41,7 @@ public class JavaxWebSocketServerFrameHandlerFactory extends JavaxWebSocketClien public JavaxWebSocketFrameHandlerMetadata getMetadata(Class endpointClass, EndpointConfig endpointConfig) { if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass)) - return createEndpointMetadata((Class)endpointClass, endpointConfig); + return createEndpointMetadata(endpointConfig); ServerEndpoint anno = endpointClass.getAnnotation(ServerEndpoint.class); if (anno == null) diff --git a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/JavaxOnCloseTest.java b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/JavaxOnCloseTest.java index 6b02668b717..65ff74b4587 100644 --- a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/JavaxOnCloseTest.java +++ b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/JavaxOnCloseTest.java @@ -50,11 +50,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class JavaxOnCloseTest { - private static BlockingArrayQueue serverEndpoints = new BlockingArrayQueue<>(); + private static final BlockingArrayQueue serverEndpoints = new BlockingArrayQueue<>(); private Server server; private ServerConnector connector; - private JavaxWebSocketClientContainer client = new JavaxWebSocketClientContainer(); + private final JavaxWebSocketClientContainer client = new JavaxWebSocketClientContainer(); @ServerEndpoint("/") public static class OnCloseEndpoint extends EventSocket @@ -84,7 +84,7 @@ public class JavaxOnCloseTest @ClientEndpoint public static class BlockingClientEndpoint extends EventSocket { - private CountDownLatch blockInClose = new CountDownLatch(1); + private final CountDownLatch blockInClose = new CountDownLatch(1); public void unBlockClose() { diff --git a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/EndpointEchoTest.java b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/EndpointEchoTest.java index 39066ed9167..e62d4c364e6 100644 --- a/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/EndpointEchoTest.java +++ b/jetty-websocket/websocket-javax-tests/src/test/java/org/eclipse/jetty/websocket/javax/tests/client/EndpointEchoTest.java @@ -18,13 +18,18 @@ package org.eclipse.jetty.websocket.javax.tests.client; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import javax.websocket.CloseReason; import javax.websocket.ContainerProvider; +import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.MessageHandler; import javax.websocket.Session; import javax.websocket.WebSocketContainer; +import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession; import org.eclipse.jetty.websocket.javax.tests.LocalServer; import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker; @@ -35,6 +40,7 @@ import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertTrue; public class EndpointEchoTest { @@ -105,4 +111,45 @@ public class EndpointEchoTest session.close(); endpoint.awaitCloseEvent("Client"); } + + @Test + public void testEchoAnonymousInstance() throws Exception + { + CountDownLatch openLatch = new CountDownLatch(1); + CountDownLatch closeLatch = new CountDownLatch(1); + BlockingQueue textMessages = new BlockingArrayQueue<>(); + Endpoint clientEndpoint = new Endpoint() + { + @Override + public void onOpen(Session session, EndpointConfig config) + { + // Cannot replace this with a lambda or it breaks ReflectUtils.findGenericClassFor(). + session.addMessageHandler(new MessageHandler.Whole() + { + @Override + public void onMessage(String message) + { + textMessages.add(message); + } + }); + openLatch.countDown(); + } + + @Override + public void onClose(Session session, CloseReason closeReason) + { + closeLatch.countDown(); + } + }; + + WebSocketContainer container = ContainerProvider.getWebSocketContainer(); + Session session = container.connectToServer(clientEndpoint, null, server.getWsUri().resolve("/echo/text")); + assertTrue(openLatch.await(5, TimeUnit.SECONDS)); + session.getBasicRemote().sendText("Echo"); + + String resp = textMessages.poll(1, TimeUnit.SECONDS); + assertThat("Response echo", resp, is("Echo")); + session.close(); + assertTrue(closeLatch.await(5, TimeUnit.SECONDS)); + } } diff --git a/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandlerFactory.java b/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandlerFactory.java index 0d08b373c4f..da579a24477 100644 --- a/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandlerFactory.java +++ b/jetty-websocket/websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandlerFactory.java @@ -195,28 +195,31 @@ public class JettyWebSocketFrameHandlerFactory extends ContainerLifeCycle private JettyWebSocketFrameHandlerMetadata createListenerMetadata(Class endpointClass) { JettyWebSocketFrameHandlerMetadata metadata = new JettyWebSocketFrameHandlerMetadata(); - MethodHandles.Lookup lookup = JettyWebSocketFrameHandlerFactory.getApplicationMethodHandleLookup(endpointClass); + MethodHandles.Lookup lookup = JettyWebSocketFrameHandlerFactory.getServerMethodHandleLookup(); - Method openMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketConnect", Session.class); + if (!WebSocketConnectionListener.class.isAssignableFrom(endpointClass)) + throw new IllegalArgumentException("Class " + endpointClass + " does not implement " + WebSocketConnectionListener.class); + + Method openMethod = ReflectUtils.findMethod(WebSocketConnectionListener.class, "onWebSocketConnect", Session.class); MethodHandle open = toMethodHandle(lookup, openMethod); metadata.setOpenHandler(open, openMethod); - Method closeMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketClose", int.class, String.class); + Method closeMethod = ReflectUtils.findMethod(WebSocketConnectionListener.class, "onWebSocketClose", int.class, String.class); MethodHandle close = toMethodHandle(lookup, closeMethod); metadata.setCloseHandler(close, closeMethod); - Method errorMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketError", Throwable.class); + Method errorMethod = ReflectUtils.findMethod(WebSocketConnectionListener.class, "onWebSocketError", Throwable.class); MethodHandle error = toMethodHandle(lookup, errorMethod); metadata.setErrorHandler(error, errorMethod); // Simple Data Listener if (WebSocketListener.class.isAssignableFrom(endpointClass)) { - Method textMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketText", String.class); + Method textMethod = ReflectUtils.findMethod(WebSocketListener.class, "onWebSocketText", String.class); MethodHandle text = toMethodHandle(lookup, textMethod); metadata.setTextHandler(StringMessageSink.class, text, textMethod); - Method binaryMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketBinary", byte[].class, int.class, int.class); + Method binaryMethod = ReflectUtils.findMethod(WebSocketListener.class, "onWebSocketBinary", byte[].class, int.class, int.class); MethodHandle binary = toMethodHandle(lookup, binaryMethod); metadata.setBinaryHandle(ByteArrayMessageSink.class, binary, binaryMethod); } @@ -224,11 +227,11 @@ public class JettyWebSocketFrameHandlerFactory extends ContainerLifeCycle // Ping/Pong Listener if (WebSocketPingPongListener.class.isAssignableFrom(endpointClass)) { - Method pongMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketPong", ByteBuffer.class); + Method pongMethod = ReflectUtils.findMethod(WebSocketPingPongListener.class, "onWebSocketPong", ByteBuffer.class); MethodHandle pong = toMethodHandle(lookup, pongMethod); metadata.setPongHandle(pong, pongMethod); - Method pingMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketPing", ByteBuffer.class); + Method pingMethod = ReflectUtils.findMethod(WebSocketPingPongListener.class, "onWebSocketPing", ByteBuffer.class); MethodHandle ping = toMethodHandle(lookup, pingMethod); metadata.setPingHandle(ping, pingMethod); } @@ -236,11 +239,11 @@ public class JettyWebSocketFrameHandlerFactory extends ContainerLifeCycle // Partial Data / Message Listener if (WebSocketPartialListener.class.isAssignableFrom(endpointClass)) { - Method textMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketPartialText", String.class, boolean.class); + Method textMethod = ReflectUtils.findMethod(WebSocketPartialListener.class, "onWebSocketPartialText", String.class, boolean.class); MethodHandle text = toMethodHandle(lookup, textMethod); metadata.setTextHandler(PartialStringMessageSink.class, text, textMethod); - Method binaryMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketPartialBinary", ByteBuffer.class, boolean.class); + Method binaryMethod = ReflectUtils.findMethod(WebSocketPartialListener.class, "onWebSocketPartialBinary", ByteBuffer.class, boolean.class); MethodHandle binary = toMethodHandle(lookup, binaryMethod); metadata.setBinaryHandle(PartialByteBufferMessageSink.class, binary, binaryMethod); } @@ -248,7 +251,7 @@ public class JettyWebSocketFrameHandlerFactory extends ContainerLifeCycle // Frame Listener if (WebSocketFrameListener.class.isAssignableFrom(endpointClass)) { - Method frameMethod = ReflectUtils.findMethod(endpointClass, "onWebSocketFrame", Frame.class); + Method frameMethod = ReflectUtils.findMethod(WebSocketFrameListener.class, "onWebSocketFrame", Frame.class); MethodHandle frame = toMethodHandle(lookup, frameMethod); metadata.setFrameHandler(frame, frameMethod); } diff --git a/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/WebSocketListenerTest.java b/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/WebSocketListenerTest.java index fe290ccba82..09050619ea2 100644 --- a/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/WebSocketListenerTest.java +++ b/jetty-websocket/websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/listeners/WebSocketListenerTest.java @@ -21,6 +21,8 @@ package org.eclipse.jetty.websocket.tests.listeners; import java.net.URI; import java.nio.ByteBuffer; import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -28,13 +30,18 @@ import java.util.stream.Stream; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.util.BlockingArrayQueue; import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.StatusCode; +import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.tests.EchoSocket; import org.eclipse.jetty.websocket.tests.EventSocket; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -60,6 +67,8 @@ public class WebSocketListenerTest contextHandler.setContextPath("/"); JettyWebSocketServletContainerInitializer.configure(contextHandler, (context, container) -> { + container.addMapping("/echo", (req, res) -> new EchoSocket()); + for (Class c : getClassListFromArguments(TextListeners.getTextListeners())) { container.addMapping("/text/" + c.getSimpleName(), (req, res) -> construct(c)); @@ -125,6 +134,47 @@ public class WebSocketListenerTest assertThat(clientEndpoint.closeReason, is("standard close")); } + @Test + public void testAnonymousListener() throws Exception + { + CountDownLatch openLatch = new CountDownLatch(1); + CountDownLatch closeLatch = new CountDownLatch(1); + BlockingQueue textMessages = new BlockingArrayQueue<>(); + WebSocketListener clientEndpoint = new WebSocketListener() + { + @Override + public void onWebSocketConnect(Session session) + { + openLatch.countDown(); + } + + @Override + public void onWebSocketText(String message) + { + textMessages.add(message); + } + + @Override + public void onWebSocketClose(int statusCode, String reason) + { + closeLatch.countDown(); + } + }; + + Session session = client.connect(clientEndpoint, serverUri.resolve("/echo")).get(5, TimeUnit.SECONDS); + assertTrue(openLatch.await(5, TimeUnit.SECONDS)); + + // Send and receive echo on client. + String payload = "hello world"; + session.getRemote().sendString(payload); + String echoMessage = textMessages.poll(5, TimeUnit.SECONDS); + assertThat(echoMessage, is(payload)); + + // Close normally. + session.close(StatusCode.NORMAL, "standard close"); + assertTrue(closeLatch.await(5, TimeUnit.SECONDS)); + } + private List> getClassListFromArguments(Stream stream) { return stream.map(arguments -> (Class)arguments.get()[0]).collect(Collectors.toList());