diff --git a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java index a8549817e40..14aafaf9107 100644 --- a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java +++ b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ListenerHolder.java @@ -80,30 +80,22 @@ public class ListenerHolder extends BaseHolder throw new IllegalStateException(msg); } - ContextHandler contextHandler = ContextHandler.getCurrentContext().getContextHandler(); - if (contextHandler != null) + ContextHandler contextHandler = null; + if (getServletHandler() != null) + contextHandler = getServletHandler().getServletContextHandler(); + if (contextHandler == null && ContextHandler.getCurrentContext() != null) + contextHandler = ContextHandler.getCurrentContext().getContextHandler(); + if (contextHandler == null) + throw new IllegalStateException("No Context"); + + _listener = getInstance(); + if (_listener == null) { - _listener = getInstance(); - if (_listener == null) - { - //create an instance of the listener and decorate it - try - { - _listener = createInstance(); - } - catch (ServletException ex) - { - Throwable cause = ex.getRootCause(); - if (cause instanceof InstantiationException) - throw (InstantiationException)cause; - if (cause instanceof IllegalAccessException) - throw (IllegalAccessException)cause; - throw ex; - } - } + //create an instance of the listener and decorate it + _listener = createInstance(); _listener = wrap(_listener, WrapFunction.class, WrapFunction::wrapEventListener); - contextHandler.addEventListener(_listener); } + contextHandler.addEventListener(_listener); } @Override diff --git a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java index fb4b8b4e28b..7cffa8b5dac 100644 --- a/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java +++ b/jetty-servlet/src/main/java/org/eclipse/jetty/servlet/ServletHandler.java @@ -741,6 +741,29 @@ public class ServletHandler extends ScopedHandler return _initialized; } + protected void initializeHolders(BaseHolder[] holders) + { + for (BaseHolder holder : holders) + { + holder.setServletHandler(this); + if (isInitialized()) + { + try + { + if (!holder.isStarted()) + { + holder.start(); + holder.initialize(); + } + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + } + } + /** * @return whether the filter chains are cached. */ @@ -768,10 +791,7 @@ public class ServletHandler extends ScopedHandler public void setListeners(ListenerHolder[] listeners) { if (listeners != null) - for (ListenerHolder holder : listeners) - { - holder.setServletHandler(this); - } + initializeHolders(listeners); updateBeans(_listeners,listeners); _listeners = listeners; } @@ -833,9 +853,6 @@ public class ServletHandler extends ScopedHandler { Objects.requireNonNull(servlet); ServletHolder[] holders = getServlets(); - if (holders != null) - holders = holders.clone(); - try { try (AutoLock l = lock()) @@ -947,8 +964,6 @@ public class ServletHandler extends ScopedHandler { Objects.requireNonNull(holder); FilterHolder[] holders = getFilters(); - if (holders != null) - holders = holders.clone(); try { @@ -1391,16 +1406,6 @@ public class ServletHandler extends ScopedHandler LOG.debug("filterNameMap={} pathFilters={} servletFilterMap={} servletPathMap={} servletNameMap={}", _filterNameMap, _filterPathMappings, _filterNameMappings, _servletPathMap, _servletNameMap); } - - try - { - if (_contextHandler != null && _contextHandler.isStarted() || _contextHandler == null && isStarted()) - initialize(); - } - catch (Exception e) - { - throw new RuntimeException(e); - } } } @@ -1459,7 +1464,7 @@ public class ServletHandler extends ScopedHandler { updateBeans(_filterMappings,filterMappings); _filterMappings = filterMappings; - if (isStarted()) + if (isRunning()) updateMappings(); invalidateChainsCache(); } @@ -1469,12 +1474,8 @@ public class ServletHandler extends ScopedHandler try (AutoLock l = lock()) { if (holders != null) - { - for (FilterHolder holder : holders) - { - holder.setServletHandler(this); - } - } + initializeHolders(holders); + updateBeans(_filters,holders); _filters = holders; updateNameMappings(); @@ -1489,7 +1490,7 @@ public class ServletHandler extends ScopedHandler { updateBeans(_servletMappings,servletMappings); _servletMappings = servletMappings; - if (isStarted()) + if (isRunning()) updateMappings(); invalidateChainsCache(); } @@ -1504,12 +1505,7 @@ public class ServletHandler extends ScopedHandler try (AutoLock l = lock()) { if (holders != null) - { - for (ServletHolder holder : holders) - { - holder.setServletHandler(this); - } - } + initializeHolders(holders); updateBeans(_servlets,holders); _servlets = holders; updateNameMappings(); diff --git a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java index 92315508e80..d94c224c63f 100644 --- a/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java +++ b/jetty-servlet/src/test/java/org/eclipse/jetty/servlet/ServletContextHandlerTest.java @@ -2007,38 +2007,179 @@ public class ServletContextHandlerTest } } + public static class TestPListener implements ServletRequestListener + { + @Override + public void requestInitialized(ServletRequestEvent sre) + { + ServletRequest request = sre.getServletRequest(); + Integer count = (Integer)request.getAttribute("testRequestListener"); + request.setAttribute("testRequestListener", count == null ? 1 : count + 1); + } + + @Override + public void requestDestroyed(ServletRequestEvent sre) + { + } + } + @Test - public void testProgrammaticFilterServlet() throws Exception + public void testProgrammaticListener() throws Exception { ServletContextHandler context = new ServletContextHandler(); ServletHandler handler = new ServletHandler(); _server.setHandler(context); context.setHandler(handler); - handler.addServletWithMapping(new ServletHolder(new TestServlet()), "/"); + // Add a servlet to report number of listeners + handler.addServletWithMapping(new ServletHolder(new HttpServlet() + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getOutputStream().print("Listeners=" + req.getAttribute("testRequestListener")); + } + }), "/"); + + // Add a listener in STOPPED, STARTING and STARTED states + handler.addListener(new ListenerHolder(TestPListener.class)); + handler.addServlet(new ServletHolder(new HttpServlet() + { + @Override + public void init() throws ServletException + { + handler.addListener(new ListenerHolder(TestPListener.class)); + } + }) + { + { + setInitOrder(1); + } + }); _server.start(); - + handler.addListener(new ListenerHolder(TestPListener.class)); String request = - "GET /hello HTTP/1.0\n" + - "Host: localhost\n" + - "\n"; + "GET /test HTTP/1.0\n" + + "Host: localhost\n" + + "\n"; String response = _connector.getResponse(request); assertThat(response, containsString("200 OK")); - assertThat(response, containsString("Test")); + assertThat(response, containsString("Listeners=3")); + } - handler.addFilterWithMapping(new FilterHolder(new MyFilter()), "/*", EnumSet.of(DispatcherType.REQUEST)); - handler.addServletWithMapping(new ServletHolder(new HelloServlet()), "/hello/*"); + public static class TestPFilter implements Filter + { + @Override + public void init(FilterConfig filterConfig) throws ServletException + { + } - _server.dumpStdErr(); + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException + { + Integer count = (Integer)request.getAttribute("testFilter"); + request.setAttribute("testFilter", count == null ? 1 : count + 1); + chain.doFilter(request, response); + } - request = - "GET /hello HTTP/1.0\n" + - "Host: localhost\n" + - "\n"; + @Override + public void destroy() + { + } + } + + @Test + public void testProgrammaticFilters() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + ServletHandler handler = new ServletHandler(); + _server.setHandler(context); + context.setHandler(handler); + + // Add a servlet to report number of filters + handler.addServletWithMapping(new ServletHolder(new HttpServlet() + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getOutputStream().print("Filters=" + req.getAttribute("testFilter")); + } + }), "/"); + + // Add a filter in STOPPED, STARTING and STARTED states + handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST)); + handler.addServlet(new ServletHolder(new HttpServlet() + { + @Override + public void init() throws ServletException + { + handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST)); + } + }) + { + { + setInitOrder(1); + } + }); + _server.start(); + handler.addFilterWithMapping(new FilterHolder(TestPFilter.class), "/*", EnumSet.of(DispatcherType.REQUEST)); + + String request = + "GET /test HTTP/1.0\n" + + "Host: localhost\n" + + "\n"; + String response = _connector.getResponse(request); + assertThat(response, containsString("200 OK")); + assertThat(response, containsString("Filters=3")); + } + + public static class TestPServlet extends HttpServlet + { + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.getOutputStream().println(req.getRequestURI()); + } + } + + @Test + public void testProgrammaticServlets() throws Exception + { + ServletContextHandler context = new ServletContextHandler(); + ServletHandler handler = new ServletHandler(); + _server.setHandler(context); + context.setHandler(handler); + + // Add a filter in STOPPED, STARTING and STARTED states + handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/one"); + handler.addServlet(new ServletHolder(new HttpServlet() + { + @Override + public void init() throws ServletException + { + handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/two"); + } + }) + { + { + setInitOrder(1); + } + }); + _server.start(); + handler.addServletWithMapping(new ServletHolder(TestPServlet.class), "/three"); + + String request = "GET /one HTTP/1.0\n" + "Host: localhost\n" + "\n"; + String response = _connector.getResponse(request); + assertThat(response, containsString("200 OK")); + assertThat(response, containsString("/one")); + request = "GET /two HTTP/1.0\n" + "Host: localhost\n" + "\n"; response = _connector.getResponse(request); assertThat(response, containsString("200 OK")); - assertThat(response, containsString("filter: filter")); - assertThat(response, containsString("Hello World")); + assertThat(response, containsString("/two")); + request = "GET /three HTTP/1.0\n" + "Host: localhost\n" + "\n"; + response = _connector.getResponse(request); + assertThat(response, containsString("200 OK")); + assertThat(response, containsString("/three")); } } \ No newline at end of file