Issue #4022 Prevent Servlet adding another Servlet (#4024)

* Issue #4022  Prevent Servlet adding Servlet and added unit tests.

Signed-off-by: Jan Bartel <janb@webtide.com>
This commit is contained in:
Jan Bartel 2019-08-28 12:29:14 +10:00 committed by GitHub
parent f4d95e0f2f
commit 4251a3e092
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 444 additions and 10 deletions

View File

@ -1067,10 +1067,13 @@ public class ServletContextHandler extends ContextHandler
return new Dispatcher(context, name); return new Dispatcher(context, name);
} }
private void checkDynamicName(String name) private void checkDynamic(String name)
{ {
if (isStarted()) if (isStarted())
throw new IllegalStateException(); throw new IllegalStateException();
if (ServletContextHandler.this.getServletHandler().isInitialized())
throw new IllegalStateException();
if (StringUtil.isBlank(name)) if (StringUtil.isBlank(name))
throw new IllegalStateException("Missing name"); throw new IllegalStateException("Missing name");
@ -1085,7 +1088,7 @@ public class ServletContextHandler extends ContextHandler
@Override @Override
public FilterRegistration.Dynamic addFilter(String filterName, Class<? extends Filter> filterClass) public FilterRegistration.Dynamic addFilter(String filterName, Class<? extends Filter> filterClass)
{ {
checkDynamicName(filterName); checkDynamic(filterName);
final ServletHandler handler = ServletContextHandler.this.getServletHandler(); final ServletHandler handler = ServletContextHandler.this.getServletHandler();
FilterHolder holder = handler.getFilter(filterName); FilterHolder holder = handler.getFilter(filterName);
@ -1114,7 +1117,7 @@ public class ServletContextHandler extends ContextHandler
@Override @Override
public FilterRegistration.Dynamic addFilter(String filterName, String className) public FilterRegistration.Dynamic addFilter(String filterName, String className)
{ {
checkDynamicName(filterName); checkDynamic(filterName);
final ServletHandler handler = ServletContextHandler.this.getServletHandler(); final ServletHandler handler = ServletContextHandler.this.getServletHandler();
FilterHolder holder = handler.getFilter(filterName); FilterHolder holder = handler.getFilter(filterName);
@ -1143,7 +1146,7 @@ public class ServletContextHandler extends ContextHandler
@Override @Override
public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) public FilterRegistration.Dynamic addFilter(String filterName, Filter filter)
{ {
checkDynamicName(filterName); checkDynamic(filterName);
final ServletHandler handler = ServletContextHandler.this.getServletHandler(); final ServletHandler handler = ServletContextHandler.this.getServletHandler();
FilterHolder holder = handler.getFilter(filterName); FilterHolder holder = handler.getFilter(filterName);
@ -1173,7 +1176,7 @@ public class ServletContextHandler extends ContextHandler
@Override @Override
public ServletRegistration.Dynamic addServlet(String servletName, Class<? extends Servlet> servletClass) public ServletRegistration.Dynamic addServlet(String servletName, Class<? extends Servlet> servletClass)
{ {
checkDynamicName(servletName); checkDynamic(servletName);
final ServletHandler handler = ServletContextHandler.this.getServletHandler(); final ServletHandler handler = ServletContextHandler.this.getServletHandler();
ServletHolder holder = handler.getServlet(servletName); ServletHolder holder = handler.getServlet(servletName);
@ -1203,7 +1206,7 @@ public class ServletContextHandler extends ContextHandler
@Override @Override
public ServletRegistration.Dynamic addServlet(String servletName, String className) public ServletRegistration.Dynamic addServlet(String servletName, String className)
{ {
checkDynamicName(servletName); checkDynamic(servletName);
final ServletHandler handler = ServletContextHandler.this.getServletHandler(); final ServletHandler handler = ServletContextHandler.this.getServletHandler();
ServletHolder holder = handler.getServlet(servletName); ServletHolder holder = handler.getServlet(servletName);
@ -1233,7 +1236,7 @@ public class ServletContextHandler extends ContextHandler
@Override @Override
public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet)
{ {
checkDynamicName(servletName); checkDynamic(servletName);
final ServletHandler handler = ServletContextHandler.this.getServletHandler(); final ServletHandler handler = ServletContextHandler.this.getServletHandler();
ServletHolder holder = handler.getServlet(servletName); ServletHolder holder = handler.getServlet(servletName);

View File

@ -116,6 +116,7 @@ public class ServletHandler extends ScopedHandler
private PathMappings<ServletHolder> _servletPathMap; private PathMappings<ServletHolder> _servletPathMap;
private ListenerHolder[] _listeners = new ListenerHolder[0]; private ListenerHolder[] _listeners = new ListenerHolder[0];
private boolean _initialized = false;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected final ConcurrentMap<String, FilterChain>[] _chainCache = new ConcurrentMap[FilterMapping.ALL]; protected final ConcurrentMap<String, FilterChain>[] _chainCache = new ConcurrentMap[FilterMapping.ALL];
@ -331,6 +332,7 @@ public class ServletHandler extends ScopedHandler
_filterPathMappings = null; _filterPathMappings = null;
_filterNameMappings = null; _filterNameMappings = null;
_servletPathMap = null; _servletPathMap = null;
_initialized = false;
} }
protected IdentityService getIdentityService() protected IdentityService getIdentityService()
@ -730,6 +732,8 @@ public class ServletHandler extends ScopedHandler
public void initialize() public void initialize()
throws Exception throws Exception
{ {
_initialized = true;
MultiException mx = new MultiException(); MultiException mx = new MultiException();
Stream.concat(Stream.concat( Stream.concat(Stream.concat(
@ -755,6 +759,14 @@ public class ServletHandler extends ScopedHandler
mx.ifExceptionThrow(); mx.ifExceptionThrow();
} }
/**
* @return true if initialized has been called, false otherwise
*/
public boolean isInitialized()
{
return _initialized;
}
/** /**
* @return whether the filter chains are cached. * @return whether the filter chains are cached.

View File

@ -22,6 +22,7 @@ import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet;
import java.util.EventListener; import java.util.EventListener;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -29,6 +30,11 @@ import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.FilterRegistration;
import javax.servlet.Servlet; import javax.servlet.Servlet;
import javax.servlet.ServletContainerInitializer; import javax.servlet.ServletContainerInitializer;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
@ -37,10 +43,13 @@ import javax.servlet.ServletContextAttributeListener;
import javax.servlet.ServletContextEvent; import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener; import javax.servlet.ServletContextListener;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.ServletRegistration;
import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestAttributeEvent; import javax.servlet.ServletRequestAttributeEvent;
import javax.servlet.ServletRequestAttributeListener; import javax.servlet.ServletRequestAttributeListener;
import javax.servlet.ServletRequestEvent; import javax.servlet.ServletRequestEvent;
import javax.servlet.ServletRequestListener; import javax.servlet.ServletRequestListener;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -71,6 +80,9 @@ import org.eclipse.jetty.server.session.SessionHandler;
import org.eclipse.jetty.util.DecoratedObjectFactory; import org.eclipse.jetty.util.DecoratedObjectFactory;
import org.eclipse.jetty.util.Decorator; import org.eclipse.jetty.util.Decorator;
import org.eclipse.jetty.util.component.AbstractLifeCycle; import org.eclipse.jetty.util.component.AbstractLifeCycle;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.log.StacklessLogging;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -110,12 +122,13 @@ public class ServletContextHandlerTest
public static class MySCIStarter extends AbstractLifeCycle implements ServletContextHandler.ServletContainerInitializerCaller public static class MySCIStarter extends AbstractLifeCycle implements ServletContextHandler.ServletContainerInitializerCaller
{ {
MySCI _sci = new MySCI(); ServletContainerInitializer _sci = null;
ContextHandler.Context _ctx; ContextHandler.Context _ctx;
MySCIStarter(ContextHandler.Context ctx) MySCIStarter(ContextHandler.Context ctx, ServletContainerInitializer sci)
{ {
_ctx = ctx; _ctx = ctx;
_sci = sci;
} }
@Override @Override
@ -417,7 +430,7 @@ public class ServletContextHandlerTest
_server.setHandler(contexts); _server.setHandler(contexts);
ServletContextHandler root = new ServletContextHandler(contexts, "/"); ServletContextHandler root = new ServletContextHandler(contexts, "/");
root.addBean(new MySCIStarter(root.getServletContext()), true); root.addBean(new MySCIStarter(root.getServletContext(), new MySCI()), true);
_server.start(); _server.start();
assertTrue((Boolean)root.getServletContext().getAttribute("MySCI.startup")); assertTrue((Boolean)root.getServletContext().getAttribute("MySCI.startup"));
assertTrue((Boolean)root.getServletContext().getAttribute("MyContextListener.contextInitialized")); assertTrue((Boolean)root.getServletContext().getAttribute("MyContextListener.contextInitialized"));
@ -595,6 +608,180 @@ public class ServletContextHandlerTest
assertEquals(0, __testServlets.get()); assertEquals(0, __testServlets.get());
} }
@Test
public void testAddServletFromServlet() throws Exception
{
//A servlet cannot be added by another servlet
Logger logger = Log.getLogger(ContextHandler.class.getName() + "ROOT");
try (StacklessLogging stackless = new StacklessLogging(logger))
{
ServletContextHandler context = new ServletContextHandler();
context.setLogger(logger);
ServletHolder holder = context.addServlet(ServletAddingServlet.class, "/start");
context.getServletHandler().setStartWithUnavailable(false);
holder.setInitOrder(0);
context.setContextPath("/");
_server.setHandler(context);
_server.start();
fail("Servlet can only be added from SCI or SCL");
}
catch (Exception e)
{
if (e instanceof ServletException)
{
assertTrue(e.getCause() instanceof IllegalStateException);
}
else
fail(e);
}
}
@Test
public void testAddFilterFromServlet() throws Exception
{
//A filter cannot be added from a servlet
Logger logger = Log.getLogger(ContextHandler.class.getName() + "ROOT");
try (StacklessLogging stackless = new StacklessLogging(logger))
{
ServletContextHandler context = new ServletContextHandler();
context.setLogger(logger);
ServletHolder holder = context.addServlet(FilterAddingServlet.class, "/filter");
context.getServletHandler().setStartWithUnavailable(false);
holder.setInitOrder(0);
context.setContextPath("/");
_server.setHandler(context);
_server.start();
fail("Filter can only be added from SCI or SCL");
}
catch (Exception e)
{
if (e instanceof ServletException)
{
assertTrue(e.getCause() instanceof IllegalStateException);
}
else
fail(e);
}
}
@Test
public void testAddServletFromFilter() throws Exception
{
//A servlet cannot be added from a Filter
Logger logger = Log.getLogger(ContextHandler.class.getName() + "ROOT");
try (StacklessLogging stackless = new StacklessLogging(logger))
{
ServletContextHandler context = new ServletContextHandler();
context.setLogger(logger);
FilterHolder holder = new FilterHolder(new Filter()
{
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
ServletRegistration rego = filterConfig.getServletContext().addServlet("hello", HelloServlet.class);
rego.addMapping("/hello/*");
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException
{
}
@Override
public void destroy()
{
}
});
context.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST));
context.getServletHandler().setStartWithUnavailable(false);
context.setContextPath("/");
_server.setHandler(context);
_server.start();
fail("Servlet can only be added from SCI or SCL");
}
catch (Exception e)
{
if (!(e instanceof IllegalStateException))
{
if (e instanceof ServletException)
{
assertTrue(e.getCause() instanceof IllegalStateException);
}
else
fail(e);
}
}
}
@Test
public void testAddServletFromSCL() throws Exception
{
//A servlet can be added from a ServletContextListener
ServletContextHandler context = new ServletContextHandler();
context.getServletHandler().setStartWithUnavailable(false);
context.setContextPath("/");
context.addEventListener(new ServletContextListener()
{
@Override
public void contextInitialized(ServletContextEvent sce)
{
ServletRegistration rego = sce.getServletContext().addServlet("hello", HelloServlet.class);
rego.addMapping("/hello/*");
}
@Override
public void contextDestroyed(ServletContextEvent sce)
{
}
});
_server.setHandler(context);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /hello HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Hello World"));
}
@Test
public void testAddServletFromSCI() throws Exception
{
//A servlet can be added from a ServletContainerInitializer
ContextHandlerCollection contexts = new ContextHandlerCollection();
_server.setHandler(contexts);
ServletContextHandler root = new ServletContextHandler(contexts, "/");
class ServletAddingSCI implements ServletContainerInitializer
{
@Override
public void onStartup(Set<Class<?>> c, ServletContext ctx) throws ServletException
{
ServletRegistration rego = ctx.addServlet("hello", HelloServlet.class);
rego.addMapping("/hello/*");
}
}
root.addBean(new MySCIStarter(root.getServletContext(), new ServletAddingSCI()), true);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /hello HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Hello World"));
}
@Test @Test
public void testAddServletAfterStart() throws Exception public void testAddServletAfterStart() throws Exception
{ {
@ -623,7 +810,126 @@ public class ServletContextHandlerTest
response = _connector.getResponse(request.toString()); response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Hello World")); assertThat("Response", response, containsString("Hello World"));
} }
@Test
public void testServletRegistrationByClass() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
context.setContextPath("/");
ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class);
reg.addMapping("/test");
_server.setHandler(context);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /test HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Test"));
}
@Test
public void testServletRegistrationByClassName() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
context.setContextPath("/");
ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class.getName());
reg.addMapping("/test");
_server.setHandler(context);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /test HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Test"));
}
@Test
public void testPartialServletRegistrationByName() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
context.setContextPath("/");
ServletHolder partial = new ServletHolder();
partial.setName("test");
context.addServlet(partial, "/test");
//complete partial servlet registration by providing name of the servlet class
ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class.getName());
assertNotNull(reg);
assertEquals(TestServlet.class.getName(), partial.getClassName());
_server.setHandler(context);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /test HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Test"));
}
@Test
public void testPartialServletRegistrationByClass() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
context.setContextPath("/");
ServletHolder partial = new ServletHolder();
partial.setName("test");
context.addServlet(partial, "/test");
//complete partial servlet registration by providing the servlet class
ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class);
assertNotNull(reg);
assertEquals(TestServlet.class.getName(), partial.getClassName());
assertSame(TestServlet.class, partial.getHeldClass());
_server.setHandler(context);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /test HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Test"));
}
@Test
public void testNullServletRegistration() throws Exception
{
ServletContextHandler context = new ServletContextHandler();
context.setContextPath("/");
ServletHolder full = new ServletHolder();
full.setName("test");
full.setHeldClass(TestServlet.class);
context.addServlet(full, "/test");
//Must return null if the servlet has been fully defined previously
ServletRegistration reg = context.getServletContext().addServlet("test", TestServlet.class);
assertNull(reg);
_server.setHandler(context);
_server.start();
StringBuffer request = new StringBuffer();
request.append("GET /test HTTP/1.0\n");
request.append("Host: localhost\n");
request.append("\n");
String response = _connector.getResponse(request.toString());
assertThat("Response", response, containsString("Test"));
}
@Test @Test
public void testHandlerBeforeServletHandler() throws Exception public void testHandlerBeforeServletHandler() throws Exception
{ {
@ -948,6 +1254,28 @@ public class ServletContextHandlerTest
writer.write("Hello World"); writer.write("Hello World");
} }
} }
public static class MyFilter implements Filter
{
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException,
ServletException
{
request.getServletContext().setAttribute("filter", "filter");
chain.doFilter(request, response);
}
@Override
public void destroy()
{
}
}
public static class DummyUtilDecorator implements org.eclipse.jetty.util.Decorator public static class DummyUtilDecorator implements org.eclipse.jetty.util.Decorator
{ {
@ -1009,6 +1337,50 @@ public class ServletContextHandlerTest
} }
} }
public static class ServletAddingServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getWriter().write("Start");
resp.getWriter().close();
}
@Override
public void init() throws ServletException
{
ServletRegistration dynamic = getServletContext().addServlet("added", AddedServlet.class);
dynamic.addMapping("/added/*");
}
}
public static class FilterAddingServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getWriter().write("Filter");
resp.getWriter().close();
}
@Override
public void init() throws ServletException
{
FilterRegistration dynamic = getServletContext().addFilter("filter", new MyFilter());
dynamic.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
}
}
public static class AddedServlet extends HttpServlet
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
{
resp.getWriter().write("Added");
resp.getWriter().close();
}
}
public static class TestServlet extends HttpServlet public static class TestServlet extends HttpServlet
{ {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;

View File

@ -18,7 +18,11 @@
package org.eclipse.jetty.servlet; package org.eclipse.jetty.servlet;
import javax.servlet.ServletException;
import javax.servlet.UnavailableException; import javax.servlet.UnavailableException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.server.handler.ContextHandler; import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.MultiException; import org.eclipse.jetty.util.MultiException;
@ -31,8 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
import java.io.IOException;
public class ServletHolderTest public class ServletHolderTest
{ {
public static class FakeServlet extends HttpServlet
{
}
@Test @Test
public void testTransitiveCompareTo() throws Exception public void testTransitiveCompareTo() throws Exception
@ -113,6 +124,42 @@ public class ServletHolderTest
assertThat(e.getCause().getMessage(), containsString("foo")); assertThat(e.getCause().getMessage(), containsString("foo"));
} }
} }
@Test
public void testWithClass() throws Exception
{
//Test adding servlet by class
try (StacklessLogging stackless = new StacklessLogging(BaseHolder.class, ServletHandler.class, ContextHandler.class, ServletContextHandler.class))
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = context.getServletHandler();
ServletHolder holder = new ServletHolder();
holder.setName("foo");
holder.setHeldClass(FakeServlet.class);
handler.addServlet(holder);
handler.start();
assertTrue(holder.isAvailable());
assertTrue(holder.isStarted());
}
}
@Test
public void testWithClassName() throws Exception
{
//Test adding servlet by classname
try (StacklessLogging stackless = new StacklessLogging(BaseHolder.class, ServletHandler.class, ContextHandler.class, ServletContextHandler.class))
{
ServletContextHandler context = new ServletContextHandler();
ServletHandler handler = context.getServletHandler();
ServletHolder holder = new ServletHolder();
holder.setName("foo");
holder.setClassName("org.eclipse.jetty.servlet.ServletHolderTest$FakeServlet");
handler.addServlet(holder);
handler.start();
assertTrue(holder.isAvailable());
assertTrue(holder.isStarted());
}
}
@Test @Test
public void testUnloadableClassName() throws Exception public void testUnloadableClassName() throws Exception