This commit is contained in:
Jan Bartel 2017-07-13 16:03:58 +02:00 committed by Joakim Erdfelt
parent e5f7fee279
commit e39c66d425
3 changed files with 140 additions and 21 deletions

View File

@ -147,6 +147,7 @@ public class DoSFilter implements Filter
private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L; private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L; private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
static final String NAME = "name";
static final String MANAGED_ATTR_INIT_PARAM = "managedAttr"; static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec"; static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
static final String DELAY_MS_INIT_PARAM = "delayMs"; static final String DELAY_MS_INIT_PARAM = "delayMs";
@ -181,12 +182,14 @@ public class DoSFilter implements Filter
private volatile boolean _trackSessions; private volatile boolean _trackSessions;
private volatile boolean _remotePort; private volatile boolean _remotePort;
private volatile boolean _enabled; private volatile boolean _enabled;
private volatile String _name;
private Semaphore _passes; private Semaphore _passes;
private volatile int _throttledRequests; private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec; private volatile int _maxRequestsPerSec;
private Queue<AsyncContext>[] _queues; private Queue<AsyncContext>[] _queues;
private AsyncListener[] _listeners; private AsyncListener[] _listeners;
private Scheduler _scheduler; private Scheduler _scheduler;
private ServletContext _context;
public void init(FilterConfig filterConfig) throws ServletException public void init(FilterConfig filterConfig) throws ServletException
{ {
@ -263,11 +266,14 @@ public class DoSFilter implements Filter
parameter = filterConfig.getInitParameter(TOO_MANY_CODE); parameter = filterConfig.getInitParameter(TOO_MANY_CODE);
setTooManyCode(parameter==null?429:Integer.parseInt(parameter)); setTooManyCode(parameter==null?429:Integer.parseInt(parameter));
_scheduler = startScheduler(); setName(filterConfig.getFilterName());
_context = filterConfig.getServletContext();
if (_context != null )
{
_context.setAttribute(filterConfig.getFilterName(), this);
}
ServletContext context = filterConfig.getServletContext(); _scheduler = startScheduler();
if (context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
context.setAttribute(filterConfig.getFilterName(), this);
} }
protected Scheduler startScheduler() throws ServletException protected Scheduler startScheduler() throws ServletException
@ -537,6 +543,11 @@ public class DoSFilter implements Filter
return USER_AUTH; return USER_AUTH;
} }
public void schedule (RateTracker tracker)
{
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
}
/** /**
* Return a request rate tracker associated with this connection; keeps * Return a request rate tracker associated with this connection; keeps
* track of this connection's request rate. If this is not the first request * track of this connection's request rate. If this is not the first request
@ -583,8 +594,9 @@ public class DoSFilter implements Filter
{ {
boolean allowed = checkWhitelist(request.getRemoteAddr()); boolean allowed = checkWhitelist(request.getRemoteAddr());
int maxRequestsPerSec = getMaxRequestsPerSec(); int maxRequestsPerSec = getMaxRequestsPerSec();
tracker = allowed ? new FixedRateTracker(loadId, type, maxRequestsPerSec) tracker = allowed ? new FixedRateTracker(_context, _name, loadId, type, maxRequestsPerSec)
: new RateTracker(loadId, type, maxRequestsPerSec); : new RateTracker(_context,_name, loadId, type, maxRequestsPerSec);
tracker.setContext(_context);
RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker); RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
if (existing != null) if (existing != null)
tracker = existing; tracker = existing;
@ -604,6 +616,16 @@ public class DoSFilter implements Filter
return tracker; return tracker;
} }
public void addToRateTracker (RateTracker tracker)
{
_rateTrackers.put(tracker.getId(), tracker);
}
public void removeFromRateTracker (String id)
{
_rateTrackers.remove(id);
}
protected boolean checkWhitelist(String candidate) protected boolean checkWhitelist(String candidate)
{ {
for (String address : _whitelist) for (String address : _whitelist)
@ -931,6 +953,25 @@ public class DoSFilter implements Filter
_maxIdleTrackerMs = value; _maxIdleTrackerMs = value;
} }
/**
* The unique name of the filter when there is more than
* one DosFilter instance.
*
* @return the name
*/
public String getName()
{
return _name;
}
/**
* @param name the name to set
*/
public void setName(String name)
{
_name = name;
}
/** /**
* Check flag to insert the DoSFilter headers into the response. * Check flag to insert the DoSFilter headers into the response.
* *
@ -1103,17 +1144,22 @@ public class DoSFilter implements Filter
* A RateTracker is associated with a connection, and stores request rate * A RateTracker is associated with a connection, and stores request rate
* data. * data.
*/ */
class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
{ {
private static final long serialVersionUID = 3534663738034577872L; private static final long serialVersionUID = 3534663738034577872L;
protected final String _filterName;
protected transient ServletContext _context;
protected final String _id; protected final String _id;
protected final int _type; protected final int _type;
protected final long[] _timestamps; protected final long[] _timestamps;
protected int _next; protected int _next;
public RateTracker(String id, int type, int maxRequestsPerSecond) public RateTracker(ServletContext context, String filterName, String id, int type, int maxRequestsPerSecond)
{ {
_context = context;
_filterName = filterName;
_id = id; _id = id;
_type = type; _type = type;
_timestamps = new long[maxRequestsPerSecond]; _timestamps = new long[maxRequestsPerSecond];
@ -1151,40 +1197,87 @@ public class DoSFilter implements Filter
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Value bound: {}", getId()); LOG.debug("Value bound: {}", getId());
_context = event.getSession().getServletContext();
} }
public void valueUnbound(HttpSessionBindingEvent event) public void valueUnbound(HttpSessionBindingEvent event)
{ {
//take the tracker out of the list of trackers //take the tracker out of the list of trackers
_rateTrackers.remove(_id); DoSFilter filter = (DoSFilter)event.getSession().getServletContext().getAttribute(_filterName);
if (LOG.isDebugEnabled()) removeFromRateTrackers(filter, _id);
LOG.debug("Tracker removed: {}", getId()); _context = null;
} }
public void sessionWillPassivate(HttpSessionEvent se) public void sessionWillPassivate(HttpSessionEvent se)
{ {
//take the tracker of the list of trackers (if its still there) //take the tracker of the list of trackers (if its still there)
_rateTrackers.remove(_id); DoSFilter filter = (DoSFilter)se.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
} }
public void sessionDidActivate(HttpSessionEvent se) public void sessionDidActivate(HttpSessionEvent se)
{ {
RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER); RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
if (tracker!=null) ServletContext context = se.getSession().getServletContext();
_rateTrackers.put(tracker.getId(),tracker); tracker.setContext(context);
DoSFilter filter = (DoSFilter)context.getAttribute(_filterName);
if (filter == null)
{
LOG.info("No filter {} for rate tracker {}", _filterName, tracker);
return;
}
addToRateTrackers(filter, tracker);
}
public void setContext (ServletContext context)
{
_context = context;
}
protected void removeFromRateTrackers (DoSFilter filter, String id)
{
if (filter == null)
return;
filter.removeFromRateTracker(id);
if (LOG.isDebugEnabled())
LOG.debug("Tracker removed: {}", getId());
}
protected void addToRateTrackers (DoSFilter filter, RateTracker tracker)
{
if (filter == null)
return;
filter.addToRateTracker(tracker);
} }
@Override @Override
public void run() public void run()
{ {
if (_context == null)
{
LOG.warn("Unknkown context for rate tracker {}", this);
return;
}
int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1); int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
long last = _timestamps[latestIndex]; long last = _timestamps[latestIndex];
boolean hasRecentRequest = last != 0 && (System.currentTimeMillis() - last) < 1000L; boolean hasRecentRequest = last != 0 && (System.currentTimeMillis() - last) < 1000L;
DoSFilter filter = (DoSFilter)_context.getAttribute(_filterName);
if (hasRecentRequest) if (hasRecentRequest)
_scheduler.schedule(this, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS); {
if (filter != null)
filter.schedule(this);
else
LOG.warn("No filter {}", _filterName);
}
else else
_rateTrackers.remove(_id); removeFromRateTrackers(filter, _id);
} }
@Override @Override
@ -1196,9 +1289,9 @@ public class DoSFilter implements Filter
class FixedRateTracker extends RateTracker class FixedRateTracker extends RateTracker
{ {
public FixedRateTracker(String id, int type, int numRecentRequestsTracked) public FixedRateTracker(ServletContext context, String filterName, String id, int type, int numRecentRequestsTracked)
{ {
super(id, type, numRecentRequestsTracked); super(context, filterName, id, type, numRecentRequestsTracked);
} }
@Override @Override

View File

@ -18,11 +18,13 @@
package org.eclipse.jetty.servlets; package org.eclipse.jetty.servlets;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.Socket; import java.net.Socket;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.EnumSet; import java.util.EnumSet;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
@ -34,12 +36,17 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpURI; import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.server.session.DefaultSessionCache;
import org.eclipse.jetty.server.session.FileSessionDataStore;
import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletTester; import org.eclipse.jetty.servlet.ServletTester;
import org.eclipse.jetty.toolchain.test.FS;
import org.eclipse.jetty.toolchain.test.TestingDir;
import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.IO;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
public abstract class AbstractDoSFilterTest public abstract class AbstractDoSFilterTest
@ -49,9 +56,23 @@ public abstract class AbstractDoSFilterTest
protected int _port; protected int _port;
protected long _requestMaxTime = 200; protected long _requestMaxTime = 200;
@Rule
public TestingDir _testDir = new TestingDir();
public void startServer(Class<? extends Filter> filter) throws Exception public void startServer(Class<? extends Filter> filter) throws Exception
{ {
_tester = new ServletTester("/ctx"); _tester = new ServletTester("/ctx");
DefaultSessionCache sessionCache = new DefaultSessionCache(_tester.getContext().getSessionHandler());
FileSessionDataStore fileStore = new FileSessionDataStore();
Path p = _testDir.getPathFile("sessions");
FS.ensureEmpty(p);
fileStore.setStoreDir(p.toFile());
sessionCache.setSessionDataStore(fileStore);
_tester.getContext().getSessionHandler().setSessionCache(sessionCache);
HttpURI uri = new HttpURI(_tester.createConnector(true)); HttpURI uri = new HttpURI(_tester.createConnector(true));
_host = uri.getHost(); _host = uri.getHost();
_port = uri.getPort(); _port = uri.getPort();

View File

@ -18,6 +18,9 @@
package org.eclipse.jetty.servlets; package org.eclipse.jetty.servlets;
import javax.servlet.ServletContext;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.servlets.DoSFilter.RateTracker; import org.eclipse.jetty.servlets.DoSFilter.RateTracker;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.Assert; import org.junit.Assert;
@ -36,7 +39,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
public void testRateIsRateExceeded() throws InterruptedException public void testRateIsRateExceeded() throws InterruptedException
{ {
DoSFilter doSFilter = new DoSFilter(); DoSFilter doSFilter = new DoSFilter();
doSFilter.setName("foo");
boolean exceeded = hitRateTracker(doSFilter,0); boolean exceeded = hitRateTracker(doSFilter,0);
Assert.assertTrue("Last hit should have exceeded",exceeded); Assert.assertTrue("Last hit should have exceeded",exceeded);
@ -49,6 +52,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
public void testWhitelist() throws Exception public void testWhitelist() throws Exception
{ {
DoSFilter filter = new DoSFilter(); DoSFilter filter = new DoSFilter();
filter.setName("foo");
filter.setWhitelist("192.168.0.1/32,10.0.0.0/8,4d8:0:a:1234:ABc:1F:b18:17,4d8:0:a:1234:ABc:1F:0:0/96"); filter.setWhitelist("192.168.0.1/32,10.0.0.0/8,4d8:0:a:1234:ABc:1F:b18:17,4d8:0:a:1234:ABc:1F:0:0/96");
Assert.assertTrue(filter.checkWhitelist("192.168.0.1")); Assert.assertTrue(filter.checkWhitelist("192.168.0.1"));
Assert.assertFalse(filter.checkWhitelist("192.168.0.2")); Assert.assertFalse(filter.checkWhitelist("192.168.0.2"));
@ -72,7 +76,8 @@ public class DoSFilterTest extends AbstractDoSFilterTest
private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws InterruptedException private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws InterruptedException
{ {
boolean exceeded = false; boolean exceeded = false;
RateTracker rateTracker = doSFilter.new RateTracker("test2",0,4); ServletContext context = new ContextHandler.StaticContext();
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2",0,4);
for (int i = 0; i < 5; i++) for (int i = 0; i < 5; i++)
{ {