This commit is contained in:
Jan Bartel 2017-07-13 16:03:58 +02:00 committed by Joakim Erdfelt
parent e5f7fee279
commit e39c66d425
3 changed files with 140 additions and 21 deletions

View File

@ -147,6 +147,7 @@ public class DoSFilter implements Filter
private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
static final String NAME = "name";
static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
static final String DELAY_MS_INIT_PARAM = "delayMs";
@ -181,12 +182,14 @@ public class DoSFilter implements Filter
private volatile boolean _trackSessions;
private volatile boolean _remotePort;
private volatile boolean _enabled;
private volatile String _name;
private Semaphore _passes;
private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec;
private Queue<AsyncContext>[] _queues;
private AsyncListener[] _listeners;
private Scheduler _scheduler;
private ServletContext _context;
public void init(FilterConfig filterConfig) throws ServletException
{
@ -263,11 +266,14 @@ public class DoSFilter implements Filter
parameter = filterConfig.getInitParameter(TOO_MANY_CODE);
setTooManyCode(parameter==null?429:Integer.parseInt(parameter));
_scheduler = startScheduler();
setName(filterConfig.getFilterName());
_context = filterConfig.getServletContext();
if (_context != null )
{
_context.setAttribute(filterConfig.getFilterName(), this);
}
ServletContext context = filterConfig.getServletContext();
if (context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
context.setAttribute(filterConfig.getFilterName(), this);
_scheduler = startScheduler();
}
protected Scheduler startScheduler() throws ServletException
@ -537,6 +543,11 @@ public class DoSFilter implements Filter
return USER_AUTH;
}
public void schedule (RateTracker tracker)
{
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
}
/**
* Return a request rate tracker associated with this connection; keeps
* track of this connection's request rate. If this is not the first request
@ -583,8 +594,9 @@ public class DoSFilter implements Filter
{
boolean allowed = checkWhitelist(request.getRemoteAddr());
int maxRequestsPerSec = getMaxRequestsPerSec();
tracker = allowed ? new FixedRateTracker(loadId, type, maxRequestsPerSec)
: new RateTracker(loadId, type, maxRequestsPerSec);
tracker = allowed ? new FixedRateTracker(_context, _name, loadId, type, maxRequestsPerSec)
: new RateTracker(_context,_name, loadId, type, maxRequestsPerSec);
tracker.setContext(_context);
RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
if (existing != null)
tracker = existing;
@ -604,6 +616,16 @@ public class DoSFilter implements Filter
return tracker;
}
public void addToRateTracker (RateTracker tracker)
{
_rateTrackers.put(tracker.getId(), tracker);
}
public void removeFromRateTracker (String id)
{
_rateTrackers.remove(id);
}
protected boolean checkWhitelist(String candidate)
{
for (String address : _whitelist)
@ -931,6 +953,25 @@ public class DoSFilter implements Filter
_maxIdleTrackerMs = value;
}
/**
* The unique name of the filter when there is more than
* one DosFilter instance.
*
* @return the name
*/
public String getName()
{
return _name;
}
/**
* @param name the name to set
*/
public void setName(String name)
{
_name = name;
}
/**
* Check flag to insert the DoSFilter headers into the response.
*
@ -1103,17 +1144,22 @@ public class DoSFilter implements Filter
* A RateTracker is associated with a connection, and stores request rate
* data.
*/
class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
{
private static final long serialVersionUID = 3534663738034577872L;
protected final String _filterName;
protected transient ServletContext _context;
protected final String _id;
protected final int _type;
protected final long[] _timestamps;
protected int _next;
public RateTracker(String id, int type, int maxRequestsPerSecond)
public RateTracker(ServletContext context, String filterName, String id, int type, int maxRequestsPerSecond)
{
_context = context;
_filterName = filterName;
_id = id;
_type = type;
_timestamps = new long[maxRequestsPerSecond];
@ -1151,40 +1197,87 @@ public class DoSFilter implements Filter
{
if (LOG.isDebugEnabled())
LOG.debug("Value bound: {}", getId());
_context = event.getSession().getServletContext();
}
public void valueUnbound(HttpSessionBindingEvent event)
{
//take the tracker out of the list of trackers
_rateTrackers.remove(_id);
if (LOG.isDebugEnabled())
LOG.debug("Tracker removed: {}", getId());
DoSFilter filter = (DoSFilter)event.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
}
public void sessionWillPassivate(HttpSessionEvent se)
{
//take the tracker of the list of trackers (if its still there)
_rateTrackers.remove(_id);
DoSFilter filter = (DoSFilter)se.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
}
public void sessionDidActivate(HttpSessionEvent se)
{
RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
if (tracker!=null)
_rateTrackers.put(tracker.getId(),tracker);
ServletContext context = se.getSession().getServletContext();
tracker.setContext(context);
DoSFilter filter = (DoSFilter)context.getAttribute(_filterName);
if (filter == null)
{
LOG.info("No filter {} for rate tracker {}", _filterName, tracker);
return;
}
addToRateTrackers(filter, tracker);
}
public void setContext (ServletContext context)
{
_context = context;
}
protected void removeFromRateTrackers (DoSFilter filter, String id)
{
if (filter == null)
return;
filter.removeFromRateTracker(id);
if (LOG.isDebugEnabled())
LOG.debug("Tracker removed: {}", getId());
}
protected void addToRateTrackers (DoSFilter filter, RateTracker tracker)
{
if (filter == null)
return;
filter.addToRateTracker(tracker);
}
@Override
public void run()
{
if (_context == null)
{
LOG.warn("Unknkown context for rate tracker {}", this);
return;
}
int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
long last = _timestamps[latestIndex];
boolean hasRecentRequest = last != 0 && (System.currentTimeMillis() - last) < 1000L;
DoSFilter filter = (DoSFilter)_context.getAttribute(_filterName);
if (hasRecentRequest)
_scheduler.schedule(this, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
{
if (filter != null)
filter.schedule(this);
else
_rateTrackers.remove(_id);
LOG.warn("No filter {}", _filterName);
}
else
removeFromRateTrackers(filter, _id);
}
@Override
@ -1196,9 +1289,9 @@ public class DoSFilter implements Filter
class FixedRateTracker extends RateTracker
{
public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
public FixedRateTracker(ServletContext context, String filterName, String id, int type, int numRecentRequestsTracked)
{
super(id, type, numRecentRequestsTracked);
super(context, filterName, id, type, numRecentRequestsTracked);
}
@Override

View File

@ -18,11 +18,13 @@
package org.eclipse.jetty.servlets;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.EnumSet;
import javax.servlet.DispatcherType;
@ -34,12 +36,17 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.server.session.DefaultSessionCache;
import org.eclipse.jetty.server.session.FileSessionDataStore;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletTester;
import org.eclipse.jetty.toolchain.test.FS;
import org.eclipse.jetty.toolchain.test.TestingDir;
import org.eclipse.jetty.util.IO;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
public abstract class AbstractDoSFilterTest
@ -49,9 +56,23 @@ public abstract class AbstractDoSFilterTest
protected int _port;
protected long _requestMaxTime = 200;
@Rule
public TestingDir _testDir = new TestingDir();
public void startServer(Class<? extends Filter> filter) throws Exception
{
_tester = new ServletTester("/ctx");
DefaultSessionCache sessionCache = new DefaultSessionCache(_tester.getContext().getSessionHandler());
FileSessionDataStore fileStore = new FileSessionDataStore();
Path p = _testDir.getPathFile("sessions");
FS.ensureEmpty(p);
fileStore.setStoreDir(p.toFile());
sessionCache.setSessionDataStore(fileStore);
_tester.getContext().getSessionHandler().setSessionCache(sessionCache);
HttpURI uri = new HttpURI(_tester.createConnector(true));
_host = uri.getHost();
_port = uri.getPort();

View File

@ -18,6 +18,9 @@
package org.eclipse.jetty.servlets;
import javax.servlet.ServletContext;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.servlets.DoSFilter.RateTracker;
import org.hamcrest.Matchers;
import org.junit.Assert;
@ -36,7 +39,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
public void testRateIsRateExceeded() throws InterruptedException
{
DoSFilter doSFilter = new DoSFilter();
doSFilter.setName("foo");
boolean exceeded = hitRateTracker(doSFilter,0);
Assert.assertTrue("Last hit should have exceeded",exceeded);
@ -49,6 +52,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
public void testWhitelist() throws Exception
{
DoSFilter filter = new DoSFilter();
filter.setName("foo");
filter.setWhitelist("192.168.0.1/32,10.0.0.0/8,4d8:0:a:1234:ABc:1F:b18:17,4d8:0:a:1234:ABc:1F:0:0/96");
Assert.assertTrue(filter.checkWhitelist("192.168.0.1"));
Assert.assertFalse(filter.checkWhitelist("192.168.0.2"));
@ -72,7 +76,8 @@ public class DoSFilterTest extends AbstractDoSFilterTest
private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws InterruptedException
{
boolean exceeded = false;
RateTracker rateTracker = doSFilter.new RateTracker("test2",0,4);
ServletContext context = new ContextHandler.StaticContext();
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2",0,4);
for (int i = 0; i < 5; i++)
{