feature to add servlet filters in a druid node via extension modules

This commit is contained in:
Himanshu Gupta 2015-03-03 22:35:02 -06:00
parent d8e199a3f5
commit a8648d8f3d
10 changed files with 214 additions and 1 deletions

View File

@ -17,8 +17,16 @@
package io.druid.server.initialization; package io.druid.server.initialization;
import java.util.Set;
import com.google.common.base.Joiner; import com.google.common.base.Joiner;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.TypeLiteral;
import com.metamx.common.ISE;
import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlets.AsyncGzipFilter; import org.eclipse.jetty.servlets.AsyncGzipFilter;
import org.eclipse.jetty.servlets.GzipFilter; import org.eclipse.jetty.servlets.GzipFilter;
@ -51,4 +59,27 @@ public abstract class BaseJettyServerInitializer implements JettyServerInitializ
// We don't actually have any precomputed .gz resources, and checking for them inside jars is expensive. // We don't actually have any precomputed .gz resources, and checking for them inside jars is expensive.
filterHolder.setInitParameter("checkGzExists", String.valueOf(false)); filterHolder.setInitParameter("checkGzExists", String.valueOf(false));
} }
public void addExtensionFilters(ServletContextHandler handler, Injector injector) {
Set<ServletFilterHolder> extensionFilters = injector.getInstance(Key.get(new TypeLiteral<Set<ServletFilterHolder>>(){}));
for (ServletFilterHolder servletFilterHolder : extensionFilters) {
// Check the Filter first to guard against people who don't read the docs and return the Class even
// when they have an instance.
FilterHolder holder = null;
if (servletFilterHolder.getFilter() != null) {
holder = new FilterHolder(servletFilterHolder.getFilter());
} else if (servletFilterHolder.getFilterClass() != null) {
holder = new FilterHolder(servletFilterHolder.getFilterClass());
} else {
throw new ISE("Filter[%s] for path[%s] didn't have a Filter!?", servletFilterHolder, servletFilterHolder.getPath());
}
if(servletFilterHolder.getInitParameters() != null) {
holder.setInitParameters(servletFilterHolder.getInitParameters());
}
handler.addFilter(holder, servletFilterHolder.getPath(), servletFilterHolder.getDispatcherType());
}
}
} }

View File

@ -29,6 +29,7 @@ import com.google.inject.Provides;
import com.google.inject.ProvisionException; import com.google.inject.ProvisionException;
import com.google.inject.Scopes; import com.google.inject.Scopes;
import com.google.inject.Singleton; import com.google.inject.Singleton;
import com.google.inject.multibindings.Multibinder;
import com.metamx.common.lifecycle.Lifecycle; import com.metamx.common.lifecycle.Lifecycle;
import com.metamx.common.logger.Logger; import com.metamx.common.logger.Logger;
import com.sun.jersey.api.core.DefaultResourceConfig; import com.sun.jersey.api.core.DefaultResourceConfig;
@ -72,6 +73,10 @@ public class JettyServerModule extends JerseyServletModule
Jerseys.addResource(binder, StatusResource.class); Jerseys.addResource(binder, StatusResource.class);
binder.bind(StatusResource.class).in(LazySingleton.class); binder.bind(StatusResource.class).in(LazySingleton.class);
//Adding empty binding for ServletFilterHolders so that injector returns
//an empty set when no external modules provide ServletFilterHolder impls
Multibinder.newSetBinder(binder, ServletFilterHolder.class);
} }
public static class DruidGuiceContainer extends GuiceContainer public static class DruidGuiceContainer extends GuiceContainer

View File

@ -0,0 +1,79 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.server.initialization;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import java.util.EnumSet;
import java.util.Map;
/**
* A ServletFilterHolder is a class that holds all of the information required to attach a Filter to a Servlet.
*
* This largely exists just to make it possible to add Filters via Guice/DI and shouldn't really exist
* anywhere that is not initialization code.
*
* Note that some of the druid nodes (router for example) use async servlets and your filter
* implementation should be able to handle those requests properly.
*/
public interface ServletFilterHolder {
/**
* Get the Filter object that should be added to the servlet.
*
* This method is considered "mutually exclusive" from the getFilterClass method.
* That is, one of them should return null and the other should return an actual value.
*
* @return The Filter object to be added to the servlet
*/
public Filter getFilter();
/**
* Get the class of the Filter object that should be added to the servlet.
*
* This method is considered "mutually exclusive" from the getFilter method.
* That is, one of them should return null and the other should return an actual value.
*
* @return The class of the Filter object to be added to the servlet
*/
public Class<? extends Filter> getFilterClass();
/**
* Get Filter initialization parameters.
*
* @return a map containing all the Filter initialization
* parameters
*/
public Map<String,String> getInitParameters();
/**
* The path that this Filter should apply to
*
* @return the path that this Filter should apply to
*/
public String getPath();
/**
* The dispatcher type that this Filter should apply to
*
* @return the enumeration of DispatcherTypes that this Filter should apply to
*/
public EnumSet<DispatcherType> getDispatcherType();
}

View File

@ -25,6 +25,7 @@ import com.google.inject.Binder;
import com.google.inject.Injector; import com.google.inject.Injector;
import com.google.inject.Key; import com.google.inject.Key;
import com.google.inject.Module; import com.google.inject.Module;
import com.google.inject.multibindings.Multibinder;
import com.google.inject.servlet.GuiceFilter; import com.google.inject.servlet.GuiceFilter;
import com.metamx.common.lifecycle.Lifecycle; import com.metamx.common.lifecycle.Lifecycle;
import com.metamx.http.client.HttpClient; import com.metamx.http.client.HttpClient;
@ -49,7 +50,15 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream; import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.POST; import javax.ws.rs.POST;
@ -59,6 +68,8 @@ import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.io.IOException; import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -101,6 +112,44 @@ public class BaseJettyTest
binder, Key.get(DruidNode.class, Self.class), new DruidNode("test", "localhost", null) binder, Key.get(DruidNode.class, Self.class), new DruidNode("test", "localhost", null)
); );
binder.bind(JettyServerInitializer.class).to(JettyServerInit.class).in(LazySingleton.class); binder.bind(JettyServerInitializer.class).to(JettyServerInit.class).in(LazySingleton.class);
Multibinder<ServletFilterHolder> multibinder = Multibinder.newSetBinder(binder, ServletFilterHolder.class);
multibinder.addBinding().toInstance(
new ServletFilterHolder()
{
@Override
public String getPath()
{
return "/*";
}
@Override
public Map<String, String> getInitParameters()
{
return null;
}
@Override
public Class<? extends Filter> getFilterClass()
{
return DummyAuthFilter.class;
}
@Override
public Filter getFilter()
{
return null;
}
@Override
public EnumSet<DispatcherType> getDispatcherType()
{
// TODO Auto-generated method stub
return null;
}
});
Jerseys.addResource(binder, SlowResource.class); Jerseys.addResource(binder, SlowResource.class);
Jerseys.addResource(binder, ExceptionResource.class); Jerseys.addResource(binder, ExceptionResource.class);
Jerseys.addResource(binder, DefaultResource.class); Jerseys.addResource(binder, DefaultResource.class);
@ -148,6 +197,7 @@ public class BaseJettyTest
{ {
final ServletContextHandler root = new ServletContextHandler(ServletContextHandler.SESSIONS); final ServletContextHandler root = new ServletContextHandler(ServletContextHandler.SESSIONS);
root.addServlet(new ServletHolder(new DefaultServlet()), "/*"); root.addServlet(new ServletHolder(new DefaultServlet()), "/*");
addExtensionFilters(root, injector);
root.addFilter(defaultGzipFilterHolder(), "/*", null); root.addFilter(defaultGzipFilterHolder(), "/*", null);
root.addFilter(GuiceFilter.class, "/*", null); root.addFilter(GuiceFilter.class, "/*", null);
@ -219,4 +269,33 @@ public class BaseJettyTest
throw new IOException(); throw new IOException();
} }
} }
public static class DummyAuthFilter implements Filter {
public static final String AUTH_HDR = "secretUser";
public static final String SECRET_USER = "bob";
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
}
@Override
public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws IOException,
ServletException
{
HttpServletRequest request = (HttpServletRequest) req;
if(request.getHeader(AUTH_HDR) == null || request.getHeader(AUTH_HDR).equals(SECRET_USER)) {
chain.doFilter(req, resp);
} else {
HttpServletResponse response = (HttpServletResponse) resp;
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Failed even fake authentication.");
}
}
@Override
public void destroy()
{
}
}
} }

View File

@ -43,6 +43,8 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import javax.servlet.http.HttpServletResponse;
public class JettyTest extends BaseJettyTest public class JettyTest extends BaseJettyTest
{ {
@ -179,4 +181,17 @@ public class JettyTest extends BaseJettyTest
latch.await(5, TimeUnit.SECONDS); latch.await(5, TimeUnit.SECONDS);
} }
@Test
public void testExtensionAuthFilter() throws Exception
{
URL url = new URL("http://localhost:" + port + "/default");
HttpURLConnection get = (HttpURLConnection) url.openConnection();
get.setRequestProperty(DummyAuthFilter.AUTH_HDR, DummyAuthFilter.SECRET_USER);
Assert.assertEquals(HttpServletResponse.SC_OK, get.getResponseCode());
get = (HttpURLConnection) url.openConnection();
get.setRequestProperty(DummyAuthFilter.AUTH_HDR, "hacker");
Assert.assertEquals(HttpServletResponse.SC_UNAUTHORIZED, get.getResponseCode());
}
} }

View File

@ -235,6 +235,7 @@ public class CliOverlord extends ServerRunnable
} }
) )
); );
addExtensionFilters(root, injector);
root.addFilter(defaultGzipFilterHolder(), "/*", null); root.addFilter(defaultGzipFilterHolder(), "/*", null);
// /status should not redirect, so add first // /status should not redirect, so add first

View File

@ -57,6 +57,7 @@ class CoordinatorJettyServerInitializer extends BaseJettyServerInitializer
} else { } else {
root.setResourceBase(config.getConsoleStatic()); root.setResourceBase(config.getConsoleStatic());
} }
addExtensionFilters(root, injector);
root.addFilter(defaultGzipFilterHolder(), "/*", null); root.addFilter(defaultGzipFilterHolder(), "/*", null);
// /status should not redirect, so add first // /status should not redirect, so add first

View File

@ -37,6 +37,7 @@ class MiddleManagerJettyServerInitializer extends BaseJettyServerInitializer
{ {
final ServletContextHandler root = new ServletContextHandler(ServletContextHandler.SESSIONS); final ServletContextHandler root = new ServletContextHandler(ServletContextHandler.SESSIONS);
root.addServlet(new ServletHolder(new DefaultServlet()), "/*"); root.addServlet(new ServletHolder(new DefaultServlet()), "/*");
addExtensionFilters(root, injector);
root.addFilter(defaultGzipFilterHolder(), "/*", null); root.addFilter(defaultGzipFilterHolder(), "/*", null);
root.addFilter(GuiceFilter.class, "/*", null); root.addFilter(GuiceFilter.class, "/*", null);

View File

@ -36,6 +36,7 @@ public class QueryJettyServerInitializer extends BaseJettyServerInitializer
{ {
final ServletContextHandler root = new ServletContextHandler(ServletContextHandler.SESSIONS); final ServletContextHandler root = new ServletContextHandler(ServletContextHandler.SESSIONS);
root.addServlet(new ServletHolder(new DefaultServlet()), "/*"); root.addServlet(new ServletHolder(new DefaultServlet()), "/*");
addExtensionFilters(root, injector);
root.addFilter(defaultGzipFilterHolder(), "/*", null); root.addFilter(defaultGzipFilterHolder(), "/*", null);
root.addFilter(GuiceFilter.class, "/*", null); root.addFilter(GuiceFilter.class, "/*", null);

View File

@ -88,8 +88,8 @@ public class RouterJettyServerInitializer extends BaseJettyServerInitializer
requestLogger requestLogger
); );
asyncQueryForwardingServlet.setTimeout(httpClientConfig.getReadTimeout().getMillis()); asyncQueryForwardingServlet.setTimeout(httpClientConfig.getReadTimeout().getMillis());
root.addServlet(new ServletHolder(asyncQueryForwardingServlet), "/druid/v2/*"); root.addServlet(new ServletHolder(asyncQueryForwardingServlet), "/druid/v2/*");
addExtensionFilters(root, injector);
root.addFilter(defaultAsyncGzipFilterHolder(), "/*", null); root.addFilter(defaultAsyncGzipFilterHolder(), "/*", null);
// Can't use '/*' here because of Guice conflicts with AsyncQueryForwardingServlet path // Can't use '/*' here because of Guice conflicts with AsyncQueryForwardingServlet path
root.addFilter(GuiceFilter.class, "/status/*", null); root.addFilter(GuiceFilter.class, "/status/*", null);