From 4251a3e092868259b6439aa286f25b8b3d9f2e75 Mon Sep 17 00:00:00 2001 From: Jan Bartel Date: Wed, 28 Aug 2019 12:29:14 +1000 Subject: [PATCH] Issue #4022 Prevent Servlet adding another Servlet (#4024) * Issue #4022 Prevent Servlet adding Servlet and added unit tests. Signed-off-by: Jan Bartel --- .../jetty/servlet/ServletContextHandler.java | 17 +- .../eclipse/jetty/servlet/ServletHandler.java | 12 + .../servlet/ServletContextHandlerTest.java | 378 +++++++++++++++++- .../jetty/servlet/ServletHolderTest.java | 47 +++ 4 files changed, 444 insertions(+), 10 deletions(-) diff --git a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletContextHandler.java b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletContextHandler.java index ba98aee7515..db28cfdbdb7 100644 --- a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletContextHandler.java +++ b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletContextHandler.java @@ -1067,10 +1067,13 @@ public class ServletContextHandler extends ContextHandler return new Dispatcher(context, name); } - private void checkDynamicName(String name) + private void checkDynamic(String name) { if (isStarted()) throw new IllegalStateException(); + + if (ServletContextHandler.this.getServletHandler().isInitialized()) + throw new IllegalStateException(); if (StringUtil.isBlank(name)) throw new IllegalStateException("Missing name"); @@ -1085,7 +1088,7 @@ public class ServletContextHandler extends ContextHandler @Override public FilterRegistration.Dynamic addFilter(String filterName, Class filterClass) { - checkDynamicName(filterName); + checkDynamic(filterName); final ServletHandler handler = ServletContextHandler.this.getServletHandler(); FilterHolder holder = handler.getFilter(filterName); @@ -1114,7 +1117,7 @@ public class ServletContextHandler extends ContextHandler @Override public FilterRegistration.Dynamic addFilter(String filterName, String className) { - checkDynamicName(filterName); + checkDynamic(filterName); final ServletHandler handler = ServletContextHandler.this.getServletHandler(); FilterHolder holder = handler.getFilter(filterName); @@ -1143,7 +1146,7 @@ public class ServletContextHandler extends ContextHandler @Override public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) { - checkDynamicName(filterName); + checkDynamic(filterName); final ServletHandler handler = ServletContextHandler.this.getServletHandler(); FilterHolder holder = handler.getFilter(filterName); @@ -1173,7 +1176,7 @@ public class ServletContextHandler extends ContextHandler @Override public ServletRegistration.Dynamic addServlet(String servletName, Class servletClass) { - checkDynamicName(servletName); + checkDynamic(servletName); final ServletHandler handler = ServletContextHandler.this.getServletHandler(); ServletHolder holder = handler.getServlet(servletName); @@ -1203,7 +1206,7 @@ public class ServletContextHandler extends ContextHandler @Override public ServletRegistration.Dynamic addServlet(String servletName, String className) { - checkDynamicName(servletName); + checkDynamic(servletName); final ServletHandler handler = ServletContextHandler.this.getServletHandler(); ServletHolder holder = handler.getServlet(servletName); @@ -1233,7 +1236,7 @@ public class ServletContextHandler extends ContextHandler @Override public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) { - checkDynamicName(servletName); + checkDynamic(servletName); final ServletHandler handler = ServletContextHandler.this.getServletHandler(); ServletHolder holder = handler.getServlet(servletName); diff --git a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java index b730ed82b52..897a9ccba98 100644 --- a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java +++ b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java @@ -116,6 +116,7 @@ public class ServletHandler extends ScopedHandler private PathMappings _servletPathMap; private ListenerHolder[] _listeners = new ListenerHolder[0]; + private boolean _initialized = false; @SuppressWarnings("unchecked") protected final ConcurrentMap[] _chainCache = new ConcurrentMap[FilterMapping.ALL]; @@ -331,6 +332,7 @@ public class ServletHandler extends ScopedHandler _filterPathMappings = null; _filterNameMappings = null; _servletPathMap = null; + _initialized = false; } protected IdentityService getIdentityService() @@ -730,6 +732,8 @@ public class ServletHandler extends ScopedHandler public void initialize() throws Exception { + _initialized = true; + MultiException mx = new MultiException(); Stream.concat(Stream.concat( @@ -755,6 +759,14 @@ public class ServletHandler extends ScopedHandler mx.ifExceptionThrow(); } + + /** + * @return true if initialized has been called, false otherwise + */ + public boolean isInitialized() + { + return _initialized; + } /** * @return whether the filter chains are cached. diff --git a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java index c870a679b94..14d7ccfa616 100644 --- a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java +++ b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Collections; +import java.util.EnumSet; import java.util.EventListener; import java.util.List; import java.util.Objects; @@ -29,6 +30,11 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.FilterRegistration; import javax.servlet.Servlet; import javax.servlet.ServletContainerInitializer; import javax.servlet.ServletContext; @@ -37,10 +43,13 @@ import javax.servlet.ServletContextAttributeListener; import javax.servlet.ServletContextEvent; import javax.servlet.ServletContextListener; import javax.servlet.ServletException; +import javax.servlet.ServletRegistration; +import javax.servlet.ServletRequest; import javax.servlet.ServletRequestAttributeEvent; import javax.servlet.ServletRequestAttributeListener; import javax.servlet.ServletRequestEvent; import javax.servlet.ServletRequestListener; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -71,6 +80,9 @@ import org.eclipse.jetty.server.session.SessionHandler; import org.eclipse.jetty.util.DecoratedObjectFactory; import org.eclipse.jetty.util.Decorator; import org.eclipse.jetty.util.component.AbstractLifeCycle; +import org.eclipse.jetty.util.log.Log; +import org.eclipse.jetty.util.log.Logger; +import org.eclipse.jetty.util.log.StacklessLogging; import org.hamcrest.Matchers; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -110,12 +122,13 @@ public class ServletContextHandlerTest public static class MySCIStarter extends AbstractLifeCycle implements ServletContextHandler.ServletContainerInitializerCaller { - MySCI _sci = new MySCI(); + ServletContainerInitializer _sci = null; ContextHandler.Context _ctx; - MySCIStarter(ContextHandler.Context ctx) + MySCIStarter(ContextHandler.Context ctx, ServletContainerInitializer sci) { _ctx = ctx; + _sci = sci; } @Override @@ -417,7 +430,7 @@ public class ServletContextHandlerTest _server.setHandler(contexts); ServletContextHandler root = new ServletContextHandler(contexts, "/"); - root.addBean(new MySCIStarter(root.getServletContext()), true); + root.addBean(new MySCIStarter(root.getServletContext(), new MySCI()), true); _server.start(); assertTrue((Boolean)root.getServletContext().getAttribute("MySCI.startup")); assertTrue((Boolean)root.getServletContext().getAttribute("MyContextListener.contextInitialized")); @@ -595,6 +608,180 @@ public class ServletContextHandlerTest assertEquals(0, __testServlets.get()); } + @Test + public void testAddServletFromServlet() throws Exception + { + //A servlet cannot be added by another servlet + Logger logger = Log.getLogger(ContextHandler.class.getName() + "ROOT"); + + try (StacklessLogging stackless = new StacklessLogging(logger)) + { + ServletContextHandler context = new ServletContextHandler(); + context.setLogger(logger); + ServletHolder holder = context.addServlet(ServletAddingServlet.class, "/start"); + context.getServletHandler().setStartWithUnavailable(false); + holder.setInitOrder(0); + context.setContextPath("/"); + _server.setHandler(context); + _server.start(); + fail("Servlet can only be added from SCI or SCL"); + } + catch (Exception e) + { + if (e instanceof ServletException) + { + assertTrue(e.getCause() instanceof IllegalStateException); + } + else + fail(e); + } + } + + @Test + public void testAddFilterFromServlet() throws Exception + { + //A filter cannot be added from a servlet + Logger logger = Log.getLogger(ContextHandler.class.getName() + "ROOT"); + + try (StacklessLogging stackless = new StacklessLogging(logger)) + { + ServletContextHandler context = new ServletContextHandler(); + context.setLogger(logger); + ServletHolder holder = context.addServlet(FilterAddingServlet.class, "/filter"); + context.getServletHandler().setStartWithUnavailable(false); + holder.setInitOrder(0); + context.setContextPath("/"); + _server.setHandler(context); + _server.start(); + fail("Filter can only be added from SCI or SCL"); + } + catch (Exception e) + { + if (e instanceof ServletException) + { + assertTrue(e.getCause() instanceof IllegalStateException); + } + else + fail(e); + } + } + + @Test + public void testAddServletFromFilter() throws Exception + { + //A servlet cannot be added from a Filter + Logger logger = Log.getLogger(ContextHandler.class.getName() + "ROOT"); + + try (StacklessLogging stackless = new StacklessLogging(logger)) + { + ServletContextHandler context = new ServletContextHandler(); + context.setLogger(logger); + FilterHolder holder = new FilterHolder(new Filter() + { + @Override + public void init(FilterConfig filterConfig) throws ServletException + { + ServletRegistration rego = filterConfig.getServletContext().addServlet("hello", HelloServlet.class); + rego.addMapping("/hello/*"); + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException + { + } + + @Override + public void destroy() + { + } + + }); + context.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST)); + context.getServletHandler().setStartWithUnavailable(false); + context.setContextPath("/"); + _server.setHandler(context); + _server.start(); + fail("Servlet can only be added from SCI or SCL"); + } + catch (Exception e) + { + if (!(e instanceof IllegalStateException)) + { + if (e instanceof ServletException) + { + assertTrue(e.getCause() instanceof IllegalStateException); + } + else + fail(e); + } + } + } + + @Test + public void testAddServletFromSCL() throws Exception + { + //A servlet can be added from a ServletContextListener + ServletContextHandler context = new ServletContextHandler(); + context.getServletHandler().setStartWithUnavailable(false); + context.setContextPath("/"); + context.addEventListener(new ServletContextListener() + { + + @Override + public void contextInitialized(ServletContextEvent sce) + { + ServletRegistration rego = sce.getServletContext().addServlet("hello", HelloServlet.class); + rego.addMapping("/hello/*"); + } + + @Override + public void contextDestroyed(ServletContextEvent sce) + { + } + + }); + _server.setHandler(context); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /hello HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Hello World")); + } + + @Test + public void testAddServletFromSCI() throws Exception + { + //A servlet can be added from a ServletContainerInitializer + ContextHandlerCollection contexts = new ContextHandlerCollection(); + _server.setHandler(contexts); + + ServletContextHandler root = new ServletContextHandler(contexts, "/"); + class ServletAddingSCI implements ServletContainerInitializer + { + @Override + public void onStartup(Set> c, ServletContext ctx) throws ServletException + { + ServletRegistration rego = ctx.addServlet("hello", HelloServlet.class); + rego.addMapping("/hello/*"); + } + } + root.addBean(new MySCIStarter(root.getServletContext(), new ServletAddingSCI()), true); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /hello HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Hello World")); + } + @Test public void testAddServletAfterStart() throws Exception { @@ -623,7 +810,126 @@ public class ServletContextHandlerTest response = _connector.getResponse(request.toString()); assertThat("Response", response, containsString("Hello World")); } + + @Test + public void testServletRegistrationByClass() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class); + reg.addMapping("/test"); + _server.setHandler(context); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /test HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Test")); + } + + @Test + public void testServletRegistrationByClassName() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class.getName()); + reg.addMapping("/test"); + + _server.setHandler(context); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /test HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Test")); + } + + @Test + public void testPartialServletRegistrationByName() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + ServletHolder partial = new ServletHolder(); + partial.setName("test"); + context.addServlet(partial, "/test"); + + //complete partial servlet registration by providing name of the servlet class + ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class.getName()); + assertNotNull(reg); + assertEquals(TestServlet.class.getName(), partial.getClassName()); + + _server.setHandler(context); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /test HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Test")); + } + + @Test + public void testPartialServletRegistrationByClass() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + ServletHolder partial = new ServletHolder(); + partial.setName("test"); + context.addServlet(partial, "/test"); + + //complete partial servlet registration by providing the servlet class + ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class); + assertNotNull(reg); + assertEquals(TestServlet.class.getName(), partial.getClassName()); + assertSame(TestServlet.class, partial.getHeldClass()); + + _server.setHandler(context); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /test HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Test")); + } + + @Test + public void testNullServletRegistration() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + ServletHolder full = new ServletHolder(); + full.setName("test"); + full.setHeldClass(TestServlet.class); + context.addServlet(full, "/test"); + + //Must return null if the servlet has been fully defined previously + ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class); + assertNull(reg); + + _server.setHandler(context); + _server.start(); + + StringBuffer request = new StringBuffer(); + request.append("GET /test HTTP/1.0\n"); + request.append("Host: localhost\n"); + request.append("\n"); + + String response = _connector.getResponse(request.toString()); + assertThat("Response", response, containsString("Test")); + } + @Test public void testHandlerBeforeServletHandler() throws Exception { @@ -948,6 +1254,28 @@ public class ServletContextHandlerTest writer.write("Hello World"); } } + + public static class MyFilter implements Filter + { + + @Override + public void init(FilterConfig filterConfig) throws ServletException + { + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, + ServletException + { + request.getServletContext().setAttribute("filter", "filter"); + chain.doFilter(request, response); + } + + @Override + public void destroy() + { + } + } public static class DummyUtilDecorator implements org.eclipse.jetty.util.Decorator { @@ -1009,6 +1337,50 @@ public class ServletContextHandlerTest } } + public static class ServletAddingServlet extends HttpServlet + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getWriter().write("Start"); + resp.getWriter().close(); + } + + @Override + public void init() throws ServletException + { + ServletRegistration dynamic = getServletContext().addServlet("added", AddedServlet.class); + dynamic.addMapping("/added/*"); + } + } + + public static class FilterAddingServlet extends HttpServlet + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getWriter().write("Filter"); + resp.getWriter().close(); + } + + @Override + public void init() throws ServletException + { + FilterRegistration dynamic = getServletContext().addFilter("filter", new MyFilter()); + dynamic.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + } + } + + public static class AddedServlet extends HttpServlet + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getWriter().write("Added"); + resp.getWriter().close(); + } + } + public static class TestServlet extends HttpServlet { private static final long serialVersionUID = 1L; diff --git a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletHolderTest.java b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletHolderTest.java index 06a4d4c3933..f9553980962 100644 --- a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletHolderTest.java +++ b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletHolderTest.java @@ -18,7 +18,11 @@ package org.eclipse.jetty.servlet; +import javax.servlet.ServletException; import javax.servlet.UnavailableException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.server.handler.ContextHandler; import org.eclipse.jetty.util.MultiException; @@ -31,8 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.IOException; + public class ServletHolderTest { + + public static class FakeServlet extends HttpServlet + { + } + @Test public void testTransitiveCompareTo() throws Exception @@ -113,6 +124,42 @@ public class ServletHolderTest assertThat(e.getCause().getMessage(), containsString("foo")); } } + + @Test + public void testWithClass() throws Exception + { + //Test adding servlet by class + try (StacklessLogging stackless = new StacklessLogging(BaseHolder.class, ServletHandler.class, ContextHandler.class, ServletContextHandler.class)) + { + ServletContextHandler context = new ServletContextHandler(); + ServletHandler handler = context.getServletHandler(); + ServletHolder holder = new ServletHolder(); + holder.setName("foo"); + holder.setHeldClass(FakeServlet.class); + handler.addServlet(holder); + handler.start(); + assertTrue(holder.isAvailable()); + assertTrue(holder.isStarted()); + } + } + + @Test + public void testWithClassName() throws Exception + { + //Test adding servlet by classname + try (StacklessLogging stackless = new StacklessLogging(BaseHolder.class, ServletHandler.class, ContextHandler.class, ServletContextHandler.class)) + { + ServletContextHandler context = new ServletContextHandler(); + ServletHandler handler = context.getServletHandler(); + ServletHolder holder = new ServletHolder(); + holder.setName("foo"); + holder.setClassName("org.eclipse.jetty.servlet.ServletHolderTest$FakeServlet"); + handler.addServlet(holder); + handler.start(); + assertTrue(holder.isAvailable()); + assertTrue(holder.isStarted()); + } + } @Test public void testUnloadableClassName() throws Exception