Issue #6965 - support programmatic upgrade for javax/jakarta websockets

Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
This commit is contained in:
Lachlan Roberts 2021-10-07 14:22:36 +11:00
parent ca8d147ec4
commit 5811b042b1
6 changed files with 237 additions and 17 deletions

View File

@ -24,7 +24,6 @@ import javax.websocket.Extension;
import javax.websocket.Extension.Parameter;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.http.pathmap.PathSpec;
import org.eclipse.jetty.http.pathmap.UriTemplatePathSpec;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.websocket.core.ExtensionConfig;
@ -139,14 +138,24 @@ public class JavaxWebSocketCreator implements WebSocketCreator
resp.setExtensions(configs);
// [JSR] Step 4: build out new ServerEndpointConfig
PathSpec pathSpec = jsrHandshakeRequest.getRequestPathSpec();
if (pathSpec instanceof UriTemplatePathSpec)
Object pathSpecObject = jsrHandshakeRequest.getRequestPathSpec();
if (pathSpecObject instanceof UriTemplatePathSpec)
{
// We have a PathParam path spec
UriTemplatePathSpec wspathSpec = (UriTemplatePathSpec)pathSpec;
String requestPath = req.getRequestPath();
// Wrap the config with the path spec information
config = new PathParamServerEndpointConfig(config, wspathSpec, requestPath);
// We can get path params from PathSpec and Request Path.
UriTemplatePathSpec pathSpec = (UriTemplatePathSpec)pathSpecObject;
Map<String, String> pathParams = pathSpec.getPathParams(req.getRequestPath());
// Wrap the config with the path spec information.
config = new PathParamServerEndpointConfig(config, pathParams);
}
else
{
Map<String, String> pathParams = jsrHandshakeRequest.getPathParams();
if (pathParams != null)
{
// Wrap the config with the path spec information.
config = new PathParamServerEndpointConfig(config, pathParams);
}
}
// [JSR] Step 5: Call modifyHandshake

View File

@ -13,12 +13,16 @@
package org.eclipse.jetty.websocket.javax.server.internal;
import java.io.IOException;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Function;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
@ -34,7 +38,9 @@ import org.eclipse.jetty.websocket.core.WebSocketComponents;
import org.eclipse.jetty.websocket.core.client.WebSocketCoreClient;
import org.eclipse.jetty.websocket.core.exception.InvalidSignatureException;
import org.eclipse.jetty.websocket.core.internal.util.ReflectUtils;
import org.eclipse.jetty.websocket.core.server.Handshaker;
import org.eclipse.jetty.websocket.core.server.WebSocketMappings;
import org.eclipse.jetty.websocket.core.server.WebSocketNegotiator;
import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents;
import org.eclipse.jetty.websocket.javax.client.internal.JavaxWebSocketClientContainer;
import org.eclipse.jetty.websocket.javax.server.config.ContainerDefaultConfigurator;
@ -46,6 +52,7 @@ import org.slf4j.LoggerFactory;
public class JavaxWebSocketServerContainer extends JavaxWebSocketClientContainer implements javax.websocket.server.ServerContainer, LifeCycle.Listener
{
public static final String JAVAX_WEBSOCKET_CONTAINER_ATTRIBUTE = javax.websocket.server.ServerContainer.class.getName();
public static final String PATH_PARAM_ATTRIBUTE = "javax.websocket.server.pathParams";
private static final Logger LOG = LoggerFactory.getLogger(JavaxWebSocketServerContainer.class);
public static JavaxWebSocketServerContainer getContainer(ServletContext servletContext)
@ -254,6 +261,35 @@ public class JavaxWebSocketServerContainer extends JavaxWebSocketClientContainer
}
}
public void upgradeHttpToWebSocket(Object httpServletRequest, Object httpServletResponse, ServerEndpointConfig sec,
Map<String, String> pathParameters) throws IOException, DeploymentException
{
HttpServletRequest request = (HttpServletRequest)httpServletRequest;
HttpServletResponse response = (HttpServletResponse)httpServletResponse;
// Decorate the provided Configurator.
components.getObjectFactory().decorate(sec.getConfigurator());
// If we have annotations merge the annotated ServerEndpointConfig with the provided one.
Class<?> endpointClass = sec.getEndpointClass();
ServerEndpoint anno = endpointClass.getAnnotation(ServerEndpoint.class);
ServerEndpointConfig config = (anno == null) ? sec
: new AnnotatedServerEndpointConfig(this, endpointClass, anno, sec);
if (LOG.isDebugEnabled())
LOG.debug("addEndpoint({}) path={} endpoint={}", config, config.getPath(), endpointClass);
validateEndpointConfig(config);
frameHandlerFactory.getMetadata(config.getEndpointClass(), config);
request.setAttribute(JavaxWebSocketServerContainer.PATH_PARAM_ATTRIBUTE, pathParameters);
// Perform the upgrade.
JavaxWebSocketCreator creator = new JavaxWebSocketCreator(this, config, getExtensionRegistry());
WebSocketNegotiator negotiator = WebSocketNegotiator.from(creator, frameHandlerFactory);
Handshaker handshaker = webSocketMappings.getHandshaker();
handshaker.upgradeRequest(negotiator, request, response, components, defaultCustomizer);
}
@Override
protected void doStart() throws Exception
{

View File

@ -60,6 +60,12 @@ public class JsrHandshakeRequest implements HandshakeRequest
return (PathSpec)delegate.getServletAttribute(PathSpec.class.getName());
}
@SuppressWarnings("unchecked")
public Map<String, String> getPathParams()
{
return (Map<String, String>)delegate.getServletAttribute(JavaxWebSocketServerContainer.PATH_PARAM_ATTRIBUTE);
}
@Override
public URI getRequestURI()
{

View File

@ -17,7 +17,6 @@ import java.util.HashMap;
import java.util.Map;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.http.pathmap.UriTemplatePathSpec;
import org.eclipse.jetty.util.URIUtil;
import org.eclipse.jetty.websocket.javax.common.PathParamProvider;
import org.eclipse.jetty.websocket.javax.common.ServerEndpointConfigWrapper;
@ -30,18 +29,13 @@ public class PathParamServerEndpointConfig extends ServerEndpointConfigWrapper i
{
private final Map<String, String> pathParamMap;
public PathParamServerEndpointConfig(ServerEndpointConfig config, UriTemplatePathSpec pathSpec, String requestPath)
public PathParamServerEndpointConfig(ServerEndpointConfig config, Map<String, String> pathParams)
{
super(config);
Map<String, String> pathMap = pathSpec.getPathParams(requestPath);
pathParamMap = new HashMap<>();
if (pathMap != null)
{
pathMap.entrySet().stream().forEach(
entry -> pathParamMap.put(entry.getKey(), URIUtil.decodePath(entry.getValue()))
);
}
if (pathParams != null)
pathParams.forEach((key, value) -> pathParamMap.put(key, URIUtil.decodePath(value)));
}
@Override

View File

@ -35,6 +35,11 @@
<groupId>org.eclipse.jetty.toolchain</groupId>
<artifactId>jetty-javax-websocket-api</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-util-ajax</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-slf4j-impl</artifactId>

View File

@ -0,0 +1,170 @@
//
// ========================================================================
// Copyright (c) 1995-2021 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//
package org.eclipse.jetty.websocket.javax.tests;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.ajax.JSON;
import org.eclipse.jetty.websocket.javax.client.internal.JavaxWebSocketClientContainer;
import org.eclipse.jetty.websocket.javax.server.config.JavaxWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.javax.server.internal.JavaxWebSocketServerContainer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class ProgrammaticWebSocketUpgradeTest
{
private static final Map<String, String> PATH_PARAMS = Map.of("param1", "value1", "param2", "value2");
private static final JSON JSON = new JSON();
private Server server;
private ServerConnector connector;
private JavaxWebSocketClientContainer client;
@BeforeEach
public void before() throws Exception
{
client = new JavaxWebSocketClientContainer();
server = new Server();
connector = new ServerConnector(server);
server.addConnector(connector);
ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS);
contextHandler.setContextPath("/");
contextHandler.addServlet(new ServletHolder(new CustomUpgradeServlet()), "/");
server.setHandler(contextHandler);
JavaxWebSocketServletContainerInitializer.configure(contextHandler, null);
server.start();
client.start();
}
@AfterEach
public void stop() throws Exception
{
client.stop();
server.stop();
}
public static class PathParamsEndpoint extends Endpoint
{
@Override
public void onOpen(Session session, EndpointConfig config)
{
try
{
session.getBasicRemote().sendText(JSON.toJSON(session.getPathParameters()));
session.close();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
}
public static class CustomUpgradeServlet extends HttpServlet
{
private JavaxWebSocketServerContainer container;
@Override
public void init(ServletConfig config) throws ServletException
{
super.init(config);
container = JavaxWebSocketServerContainer.getContainer(getServletContext());
}
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
try
{
switch (request.getServletPath())
{
case "/echo":
{
ServerEndpointConfig sec = ServerEndpointConfig.Builder.create(EchoSocket.class, "/").build();
HashMap<String, String> pathParams = new HashMap<>();
container.upgradeHttpToWebSocket(request, response, sec, pathParams);
break;
}
case "/pathParams":
{
ServerEndpointConfig sec = ServerEndpointConfig.Builder.create(PathParamsEndpoint.class, "/").build();
container.upgradeHttpToWebSocket(request, response, sec, PATH_PARAMS);
break;
}
default:
throw new IllegalStateException();
}
}
catch (DeploymentException e)
{
throw new ServletException(e);
}
}
}
@Test
public void testWebSocketUpgrade() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/echo");
EventSocket socket = new EventSocket();
try (Session session = client.connectToServer(socket, uri))
{
session.getBasicRemote().sendText("hello world");
}
assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS));
String msg = socket.textMessages.poll();
assertThat(msg, is("hello world"));
assertThat(socket.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
}
@Test
public void testPathParameters() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/pathParams");
EventSocket socket = new EventSocket();
client.connectToServer(socket, uri);
assertTrue(socket.closeLatch.await(5, TimeUnit.SECONDS));
String msg = socket.textMessages.poll();
assertThat(JSON.fromJSON(msg), is(PATH_PARAMS));
assertThat(socket.closeReason.getCloseCode(), is(CloseReason.CloseCodes.NORMAL_CLOSURE));
}
}