401908 - Enhance DosFilter to allow dynamic configuration of attributes.

This commit is contained in:
Simone Bordet 2013-02-27 16:23:17 +01:00
parent 0e9f74ad29
commit 90bab0eb66
5 changed files with 523 additions and 367 deletions

View File

@ -84,6 +84,12 @@
<artifactId>javax.servlet</artifactId> <artifactId>javax.servlet</artifactId>
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-jmx</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.eclipse.jetty</groupId> <groupId>org.eclipse.jetty</groupId>
<artifactId>test-jetty-servlet</artifactId> <artifactId>test-jetty-servlet</artifactId>

View File

@ -19,16 +19,17 @@
package org.eclipse.jetty.servlets; package org.eclipse.jetty.servlets;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.util.ArrayList;
import java.util.HashSet; import java.util.Iterator;
import java.util.Map; import java.util.List;
import java.util.Queue; import java.util.Queue;
import java.util.StringTokenizer;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.FilterConfig; import javax.servlet.FilterConfig;
@ -54,9 +55,9 @@ import org.eclipse.jetty.util.thread.Timeout;
/** /**
* Denial of Service filter * Denial of Service filter
* * <p/>
* <p> * <p>
* This filter is based on the {@link QoSFilter}. it is useful for limiting * This filter is useful for limiting
* exposure to abuse from request flooding, whether malicious, or as a result of * exposure to abuse from request flooding, whether malicious, or as a result of
* a misconfigured client. * a misconfigured client.
* <p> * <p>
@ -73,46 +74,46 @@ import org.eclipse.jetty.util.thread.Timeout;
* implemented, in order to uniquely identify authenticated users. * implemented, in order to uniquely identify authenticated users.
* <p> * <p>
* The following init parameters control the behavior of the filter:<dl> * The following init parameters control the behavior of the filter:<dl>
* * <p/>
* <dt>maxRequestsPerSec</dt> * <dt>maxRequestsPerSec</dt>
* <dd>the maximum number of requests from a connection per * <dd>the maximum number of requests from a connection per
* second. Requests in excess of this are first delayed, * second. Requests in excess of this are first delayed,
* then throttled.</dd> * then throttled.</dd>
* * <p/>
* <dt>delayMs</dt> * <dt>delayMs</dt>
* <dd>is the delay given to all requests over the rate limit, * <dd>is the delay given to all requests over the rate limit,
* before they are considered at all. -1 means just reject request, * before they are considered at all. -1 means just reject request,
* 0 means no delay, otherwise it is the delay.</dd> * 0 means no delay, otherwise it is the delay.</dd>
* * <p/>
* <dt>maxWaitMs</dt> * <dt>maxWaitMs</dt>
* <dd>how long to blocking wait for the throttle semaphore.</dd> * <dd>how long to blocking wait for the throttle semaphore.</dd>
* * <p/>
* <dt>throttledRequests</dt> * <dt>throttledRequests</dt>
* <dd>is the number of requests over the rate limit able to be * <dd>is the number of requests over the rate limit able to be
* considered at once.</dd> * considered at once.</dd>
* * <p/>
* <dt>throttleMs</dt> * <dt>throttleMs</dt>
* <dd>how long to async wait for semaphore.</dd> * <dd>how long to async wait for semaphore.</dd>
* * <p/>
* <dt>maxRequestMs</dt> * <dt>maxRequestMs</dt>
* <dd>how long to allow this request to run.</dd> * <dd>how long to allow this request to run.</dd>
* * <p/>
* <dt>maxIdleTrackerMs</dt> * <dt>maxIdleTrackerMs</dt>
* <dd>how long to keep track of request rates for a connection, * <dd>how long to keep track of request rates for a connection,
* before deciding that the user has gone away, and discarding it</dd> * before deciding that the user has gone away, and discarding it</dd>
* * <p/>
* <dt>insertHeaders</dt> * <dt>insertHeaders</dt>
* <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd> * <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd>
* * <p/>
* <dt>trackSessions</dt> * <dt>trackSessions</dt>
* <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd> * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
* * <p/>
* <dt>remotePort</dt> * <dt>remotePort</dt>
* <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd> * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
* * <p/>
* <dt>ipWhitelist</dt> * <dt>ipWhitelist</dt>
* <dd>a comma-separated list of IP addresses that will not be rate limited</dd> * <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
* * <p/>
* <dt>managedAttr</dt> * <dt>managedAttr</dt>
* <dd>if set to true, then this servlet is set as a {@link ServletContext} attribute with the * <dd>if set to true, then this servlet is set as a {@link ServletContext} attribute with the
* filter name as the attribute name. This allows context external mechanism (eg JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to * filter name as the attribute name. This allows context external mechanism (eg JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to
@ -120,64 +121,62 @@ import org.eclipse.jetty.util.thread.Timeout;
* </dl> * </dl>
* </p> * </p>
*/ */
public class DoSFilter implements Filter public class DoSFilter implements Filter
{ {
private static final Logger LOG = Log.getLogger(DoSFilter.class); private static final Logger LOG = Log.getLogger(DoSFilter.class);
final static String __TRACKER = "DoSFilter.Tracker"; private static final Pattern IP_PATTERN = Pattern.compile("(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})");
final static String __THROTTLED = "DoSFilter.Throttled"; private static final Pattern CIDR_PATTERN = Pattern.compile(IP_PATTERN + "/(\\d{1,2})");
final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25; private static final String __TRACKER = "DoSFilter.Tracker";
final static int __DEFAULT_DELAY_MS = 100; private static final String __THROTTLED = "DoSFilter.Throttled";
final static int __DEFAULT_THROTTLE = 5;
final static int __DEFAULT_WAIT_MS=50;
final static long __DEFAULT_THROTTLE_MS = 30000L;
final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
final static String MANAGED_ATTR_INIT_PARAM="managedAttr"; private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec"; private static final int __DEFAULT_DELAY_MS = 100;
final static String DELAY_MS_INIT_PARAM = "delayMs"; private static final int __DEFAULT_THROTTLE = 5;
final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests"; private static final int __DEFAULT_MAX_WAIT_MS = 50;
final static String MAX_WAIT_INIT_PARAM="maxWaitMs"; private static final long __DEFAULT_THROTTLE_MS = 30000L;
final static String THROTTLE_MS_INIT_PARAM = "throttleMs"; private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs"; private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
final static String REMOTE_PORT_INIT_PARAM="remotePort";
final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
final static int USER_AUTH = 2; static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
final static int USER_SESSION = 2; static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
final static int USER_IP = 1; static final String DELAY_MS_INIT_PARAM = "delayMs";
final static int USER_UNKNOWN = 0; static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
static final String REMOTE_PORT_INIT_PARAM = "remotePort";
static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
static final String ENABLED_INIT_PARAM = "enabled";
ServletContext _context; private static final int USER_AUTH = 2;
private static final int USER_SESSION = 2;
protected String _name; private static final int USER_IP = 1;
protected long _delayMs; private static final int USER_UNKNOWN = 0;
protected long _throttleMs;
protected long _maxWaitMs;
protected long _maxRequestMs;
protected long _maxIdleTrackerMs;
protected boolean _insertHeaders;
protected boolean _trackSessions;
protected boolean _remotePort;
protected int _throttledRequests;
protected Semaphore _passes;
protected Queue<Continuation>[] _queue;
protected ContinuationListener[] _listener;
protected int _maxRequestsPerSec;
protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
protected String _whitelistStr;
private final HashSet<String> _whitelist = new HashSet<String>();
private ServletContext _context;
private volatile long _delayMs;
private volatile long _throttleMs;
private volatile long _maxWaitMs;
private volatile long _maxRequestMs;
private volatile long _maxIdleTrackerMs;
private volatile boolean _insertHeaders;
private volatile boolean _trackSessions;
private volatile boolean _remotePort;
private volatile boolean _enabled;
private Semaphore _passes;
private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec;
private Queue<Continuation>[] _queue;
private ContinuationListener[] _listeners;
private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>();
private final List<String> _whitelist = new CopyOnWriteArrayList<String>();
private final Timeout _requestTimeoutQ = new Timeout(); private final Timeout _requestTimeoutQ = new Timeout();
private final Timeout _trackerTimeoutQ = new Timeout(); private final Timeout _trackerTimeoutQ = new Timeout();
private Thread _timerThread; private Thread _timerThread;
private volatile boolean _running; private volatile boolean _running;
@ -186,13 +185,13 @@ public class DoSFilter implements Filter
_context = filterConfig.getServletContext(); _context = filterConfig.getServletContext();
_queue = new Queue[getMaxPriority() + 1]; _queue = new Queue[getMaxPriority() + 1];
_listener = new ContinuationListener[getMaxPriority() + 1]; _listeners = new ContinuationListener[getMaxPriority() + 1];
for (int p = 0; p < _queue.length; p++) for (int p = 0; p < _queue.length; p++)
{ {
_queue[p] = new ConcurrentLinkedQueue<Continuation>(); _queue[p] = new ConcurrentLinkedQueue<Continuation>();
final int priority = p; final int priority = p;
_listener[p] = new ContinuationListener() _listeners[p] = new ContinuationListener()
{ {
public void onComplete(Continuation continuation) public void onComplete(Continuation continuation)
{ {
@ -207,55 +206,65 @@ public class DoSFilter implements Filter
_rateTrackers.clear(); _rateTrackers.clear();
int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC; int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null) String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM);
baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM)); if (parameter != null)
_maxRequestsPerSec = baseRateLimit; maxRequests = Integer.parseInt(parameter);
setMaxRequestsPerSec(maxRequests);
long delay = __DEFAULT_DELAY_MS; long delay = __DEFAULT_DELAY_MS;
if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null) parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM);
delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM)); if (parameter != null)
_delayMs = delay; delay = Long.parseLong(parameter);
setDelayMs(delay);
int throttledRequests = __DEFAULT_THROTTLE; int throttledRequests = __DEFAULT_THROTTLE;
if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null) parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM);
throttledRequests = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM)); if (parameter != null)
_passes = new Semaphore(throttledRequests,true); throttledRequests = Integer.parseInt(parameter);
_throttledRequests = throttledRequests; setThrottledRequests(throttledRequests);
long wait = __DEFAULT_WAIT_MS; long maxWait = __DEFAULT_MAX_WAIT_MS;
if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null) parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM);
wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM)); if (parameter != null)
_maxWaitMs = wait; maxWait = Long.parseLong(parameter);
setMaxWaitMs(maxWait);
long suspend = __DEFAULT_THROTTLE_MS; long throttle = __DEFAULT_THROTTLE_MS;
if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null) parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM);
suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM)); if (parameter != null)
_throttleMs = suspend; throttle = Long.parseLong(parameter);
setThrottleMs(throttle);
long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM; long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null ) parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM);
maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM)); if (parameter != null)
_maxRequestMs = maxRequestMs; maxRequestMs = Long.parseLong(parameter);
setMaxRequestMs(maxRequestMs);
long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM; long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null ) parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM);
maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM)); if (parameter != null)
_maxIdleTrackerMs = maxIdleTrackerMs; maxIdleTrackerMs = Long.parseLong(parameter);
setMaxIdleTrackerMs(maxIdleTrackerMs);
_whitelistStr = ""; String whiteList = "";
if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null ) parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
_whitelistStr = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM); if (parameter != null)
initWhitelist(); whiteList = parameter;
setWhitelist(whiteList);
String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM); parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
_insertHeaders = tmp==null || Boolean.parseBoolean(tmp); setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter));
tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM); parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
_trackSessions = tmp==null || Boolean.parseBoolean(tmp); setTrackSessions(parameter == null || Boolean.parseBoolean(parameter));
tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM); parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
_remotePort = tmp!=null&& Boolean.parseBoolean(tmp); setRemotePort(parameter != null && Boolean.parseBoolean(parameter));
parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM);
setEnabled(parameter == null || Boolean.parseBoolean(parameter));
_requestTimeoutQ.setNow(); _requestTimeoutQ.setNow();
_requestTimeoutQ.setDuration(_maxRequestMs); _requestTimeoutQ.setDuration(_maxRequestMs);
@ -272,17 +281,10 @@ public class DoSFilter implements Filter
{ {
while (_running) while (_running)
{ {
long now; long now = _requestTimeoutQ.setNow();
synchronized (_requestTimeoutQ)
{
now = _requestTimeoutQ.setNow();
_requestTimeoutQ.tick(); _requestTimeoutQ.tick();
}
synchronized (_trackerTimeoutQ)
{
_trackerTimeoutQ.setNow(now); _trackerTimeoutQ.setNow(now);
_trackerTimeoutQ.tick(); _trackerTimeoutQ.tick();
}
try try
{ {
Thread.sleep(100); Thread.sleep(100);
@ -295,7 +297,7 @@ public class DoSFilter implements Filter
} }
finally finally
{ {
LOG.info("DoSFilter timer exited"); LOG.debug("DoSFilter timer exited");
} }
} }
}); });
@ -305,11 +307,18 @@ public class DoSFilter implements Filter
_context.setAttribute(filterConfig.getFilterName(), this); _context.setAttribute(filterConfig.getFilterName(), this);
} }
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
{ {
final HttpServletRequest srequest = (HttpServletRequest)request; doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
final HttpServletResponse sresponse = (HttpServletResponse)response; }
protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
{
if (!isEnabled())
{
filterChain.doFilter(request, response);
return;
}
final long now = _requestTimeoutQ.getNow(); final long now = _requestTimeoutQ.getNow();
@ -329,22 +338,24 @@ public class DoSFilter implements Filter
// pass it through if we are not currently over the rate limit // pass it through if we are not currently over the rate limit
if (!overRateLimit) if (!overRateLimit)
{ {
doFilterChain(filterchain,srequest,sresponse); doFilterChain(filterChain, request, response);
return; return;
} }
// We are over the limit. // We are over the limit.
LOG.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal()); LOG.warn("DOS ALERT: ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
// So either reject it, delay it or throttle it // So either reject it, delay it or throttle it
switch((int)_delayMs) long delayMs = getDelayMs();
boolean insertHeaders = isInsertHeaders();
switch ((int)delayMs)
{ {
case -1: case -1:
{ {
// Reject this request // Reject this request
if (_insertHeaders) if (insertHeaders)
((HttpServletResponse)response).addHeader("DoSFilter","unavailable"); response.addHeader("DoSFilter", "unavailable");
((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
return; return;
} }
case 0: case 0:
@ -356,12 +367,12 @@ public class DoSFilter implements Filter
default: default:
{ {
// insert a delay before throttling the request // insert a delay before throttling the request
if (_insertHeaders) if (insertHeaders)
((HttpServletResponse)response).addHeader("DoSFilter","delayed"); response.addHeader("DoSFilter", "delayed");
Continuation continuation = ContinuationSupport.getContinuation(request); Continuation continuation = ContinuationSupport.getContinuation(request);
request.setAttribute(__TRACKER, tracker); request.setAttribute(__TRACKER, tracker);
if (_delayMs > 0) if (delayMs > 0)
continuation.setTimeout(_delayMs); continuation.setTimeout(delayMs);
continuation.suspend(); continuation.suspend();
return; return;
} }
@ -373,7 +384,7 @@ public class DoSFilter implements Filter
try try
{ {
// check if we can afford to accept another request at this time // check if we can afford to accept another request at this time
accepted = _passes.tryAcquire(_maxWaitMs,TimeUnit.MILLISECONDS); accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
if (!accepted) if (!accepted)
{ {
@ -381,17 +392,18 @@ public class DoSFilter implements Filter
final Continuation continuation = ContinuationSupport.getContinuation(request); final Continuation continuation = ContinuationSupport.getContinuation(request);
Boolean throttled = (Boolean)request.getAttribute(__THROTTLED); Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
if (throttled!=Boolean.TRUE && _throttleMs>0) long throttleMs = getThrottleMs();
if (throttled != Boolean.TRUE && throttleMs > 0)
{ {
int priority = getPriority(request, tracker); int priority = getPriority(request, tracker);
request.setAttribute(__THROTTLED, Boolean.TRUE); request.setAttribute(__THROTTLED, Boolean.TRUE);
if (_insertHeaders) if (isInsertHeaders())
((HttpServletResponse)response).addHeader("DoSFilter","throttled"); response.addHeader("DoSFilter", "throttled");
if (_throttleMs > 0) if (throttleMs > 0)
continuation.setTimeout(_throttleMs); continuation.setTimeout(throttleMs);
continuation.suspend(); continuation.suspend();
continuation.addContinuationListener(_listener[priority]); continuation.addContinuationListener(_listeners[priority]);
_queue[priority].add(continuation); _queue[priority].add(continuation);
return; return;
} }
@ -407,19 +419,19 @@ public class DoSFilter implements Filter
// if we were accepted (either immediately or after throttle) // if we were accepted (either immediately or after throttle)
if (accepted) if (accepted)
// call the chain // call the chain
doFilterChain(filterchain,srequest,sresponse); doFilterChain(filterChain, request, response);
else else
{ {
// fail the request // fail the request
if (_insertHeaders) if (isInsertHeaders())
((HttpServletResponse)response).addHeader("DoSFilter","unavailable"); response.addHeader("DoSFilter", "unavailable");
((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
} }
} }
catch (InterruptedException e) catch (InterruptedException e)
{ {
_context.log("DoS", e); _context.log("DoS", e);
((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
} }
finally finally
{ {
@ -440,15 +452,7 @@ public class DoSFilter implements Filter
} }
} }
/** protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
* @param chain
* @param request
* @param response
* @throws IOException
* @throws ServletException
*/
protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
throws IOException, ServletException
{ {
final Thread thread = Thread.currentThread(); final Thread thread = Thread.currentThread();
@ -461,25 +465,20 @@ public class DoSFilter implements Filter
}; };
try try
{
synchronized (_requestTimeoutQ)
{ {
_requestTimeoutQ.schedule(requestTimeout); _requestTimeoutQ.schedule(requestTimeout);
}
chain.doFilter(request, response); chain.doFilter(request, response);
} }
finally finally
{
synchronized (_requestTimeoutQ)
{ {
requestTimeout.cancel(); requestTimeout.cancel();
} }
} }
}
/** /**
* Takes drastic measures to return this response and stop this thread. * Takes drastic measures to return this response and stop this thread.
* Due to the way the connection is interrupted, may return mixed up headers. * Due to the way the connection is interrupted, may return mixed up headers.
*
* @param request current request * @param request current request
* @param response current response, which must be stopped * @param response current response, which must be stopped
* @param thread the handling thread * @param thread the handling thread
@ -514,11 +513,11 @@ public class DoSFilter implements Filter
/** /**
* Get priority for this request, based on user type * Get priority for this request, based on user type
* *
* @param request * @param request the current request
* @param tracker * @param tracker the rate tracker for this request
* @return priority * @return the priority for this request
*/ */
protected int getPriority(ServletRequest request, RateTracker tracker) protected int getPriority(HttpServletRequest request, RateTracker tracker)
{ {
if (extractUserId(request) != null) if (extractUserId(request) != null)
return USER_AUTH; return USER_AUTH;
@ -540,21 +539,20 @@ public class DoSFilter implements Filter
* track of this connection's request rate. If this is not the first request * track of this connection's request rate. If this is not the first request
* from this connection, return the existing object with the stored stats. * from this connection, return the existing object with the stored stats.
* If it is the first request, then create a new request tracker. * If it is the first request, then create a new request tracker.
* * <p/>
* Assumes that each connection has an identifying characteristic, and goes * Assumes that each connection has an identifying characteristic, and goes
* through them in order, taking the first that matches: user id (logged * through them in order, taking the first that matches: user id (logged
* in), session id, client IP address. Unidentifiable connections are lumped * in), session id, client IP address. Unidentifiable connections are lumped
* into one. * into one.
* * <p/>
* When a session expires, its rate tracker is automatically deleted. * When a session expires, its rate tracker is automatically deleted.
* *
* @param request * @param request the current request
* @return the request rate tracker for the current connection * @return the request rate tracker for the current connection
*/ */
public RateTracker getRateTracker(ServletRequest request) public RateTracker getRateTracker(ServletRequest request)
{ {
HttpServletRequest srequest = (HttpServletRequest)request; HttpSession session = ((HttpServletRequest)request).getSession(false);
HttpSession session=srequest.getSession(false);
String loadId = extractUserId(request); String loadId = extractUserId(request);
final int type; final int type;
@ -580,48 +578,79 @@ public class DoSFilter implements Filter
if (tracker == null) if (tracker == null)
{ {
RateTracker t; boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr());
if (_whitelist.contains(request.getRemoteAddr())) tracker = allowed ? new FixedRateTracker(loadId, type, _maxRequestsPerSec)
{ : new RateTracker(loadId, type, _maxRequestsPerSec);
t = new FixedRateTracker(loadId,type,_maxRequestsPerSec); RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
} if (existing != null)
else tracker = existing;
{
t = new RateTracker(loadId,type,_maxRequestsPerSec);
}
tracker=_rateTrackers.putIfAbsent(loadId,t);
if (tracker==null)
tracker=t;
if (type == USER_IP) if (type == USER_IP)
{ {
// USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
synchronized (_trackerTimeoutQ)
{
_trackerTimeoutQ.schedule(tracker); _trackerTimeoutQ.schedule(tracker);
} }
}
else if (session != null) else if (session != null)
{
// USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
session.setAttribute(__TRACKER, tracker); session.setAttribute(__TRACKER, tracker);
} }
}
return tracker; return tracker;
} }
protected boolean checkWhitelist(List<String> whitelist, String candidate)
{
for (String address : whitelist)
{
if (address.contains("/"))
{
if (subnetMatch(address, candidate))
return true;
}
else
{
if (address.equals(candidate))
return true;
}
}
return false;
}
protected boolean subnetMatch(String subnetAddress, String candidate)
{
Matcher matcher = CIDR_PATTERN.matcher(subnetAddress);
int subnet = intFromAddress(matcher);
int prefix = Integer.parseInt(matcher.group(5));
// Sets the most significant prefix bits to 1
// If prefix == 8 => 11111111_00000000_00000000_00000000
int mask = ~((1 << (32 - prefix)) - 1);
int ip = intFromAddress(IP_PATTERN.matcher(candidate));
return (ip & mask) == (subnet & mask);
}
private int intFromAddress(Matcher matcher)
{
int result = 0;
if (matcher.matches())
{
for (int i = 0; i < 4; ++i)
{
int b = Integer.parseInt(matcher.group(i + 1));
result |= b << 8 * (3 - i);
}
return result;
}
throw new IllegalStateException();
}
public void destroy() public void destroy()
{ {
_running = false; _running = false;
_timerThread.interrupt(); _timerThread.interrupt();
synchronized (_requestTimeoutQ)
{
_requestTimeoutQ.cancelAll(); _requestTimeoutQ.cancelAll();
}
synchronized (_trackerTimeoutQ)
{
_trackerTimeoutQ.cancelAll(); _trackerTimeoutQ.cancelAll();
}
_rateTrackers.clear(); _rateTrackers.clear();
_whitelist.clear(); _whitelist.clear();
} }
@ -630,7 +659,7 @@ public class DoSFilter implements Filter
* Returns the user id, used to track this connection. * Returns the user id, used to track this connection.
* This SHOULD be overridden by subclasses. * This SHOULD be overridden by subclasses.
* *
* @param request * @param request the current request
* @return a unique user id, if logged in; otherwise null. * @return a unique user id, if logged in; otherwise null.
*/ */
protected String extractUserId(ServletRequest request) protected String extractUserId(ServletRequest request)
@ -638,21 +667,6 @@ public class DoSFilter implements Filter
return null; return null;
} }
/* ------------------------------------------------------------ */
/**
* Initialize the IP address whitelist
*/
protected void initWhitelist()
{
_whitelist.clear();
StringTokenizer tokenizer = new StringTokenizer(_whitelistStr, ",");
while (tokenizer.hasMoreTokens())
_whitelist.add(tokenizer.nextToken().trim());
LOG.info("Whitelisted IP addresses: {}", _whitelist.toString());
}
/* ------------------------------------------------------------ */
/** /**
* Get maximum number of requests from a connection per * Get maximum number of requests from a connection per
* second. Requests in excess of this are first delayed, * second. Requests in excess of this are first delayed,
@ -665,7 +679,6 @@ public class DoSFilter implements Filter
return _maxRequestsPerSec; return _maxRequestsPerSec;
} }
/* ------------------------------------------------------------ */
/** /**
* Get maximum number of requests from a connection per * Get maximum number of requests from a connection per
* second. Requests in excess of this are first delayed, * second. Requests in excess of this are first delayed,
@ -678,7 +691,6 @@ public class DoSFilter implements Filter
_maxRequestsPerSec = value; _maxRequestsPerSec = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get delay (in milliseconds) that is applied to all requests * Get delay (in milliseconds) that is applied to all requests
* over the rate limit, before they are considered at all. * over the rate limit, before they are considered at all.
@ -688,7 +700,6 @@ public class DoSFilter implements Filter
return _delayMs; return _delayMs;
} }
/* ------------------------------------------------------------ */
/** /**
* Set delay (in milliseconds) that is applied to all requests * Set delay (in milliseconds) that is applied to all requests
* over the rate limit, before they are considered at all. * over the rate limit, before they are considered at all.
@ -700,7 +711,6 @@ public class DoSFilter implements Filter
_delayMs = value; _delayMs = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get maximum amount of time (in milliseconds) the filter will * Get maximum amount of time (in milliseconds) the filter will
* blocking wait for the throttle semaphore. * blocking wait for the throttle semaphore.
@ -712,7 +722,6 @@ public class DoSFilter implements Filter
return _maxWaitMs; return _maxWaitMs;
} }
/* ------------------------------------------------------------ */
/** /**
* Set maximum amount of time (in milliseconds) the filter will * Set maximum amount of time (in milliseconds) the filter will
* blocking wait for the throttle semaphore. * blocking wait for the throttle semaphore.
@ -724,7 +733,6 @@ public class DoSFilter implements Filter
_maxWaitMs = value; _maxWaitMs = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get number of requests over the rate limit able to be * Get number of requests over the rate limit able to be
* considered at once. * considered at once.
@ -736,7 +744,6 @@ public class DoSFilter implements Filter
return _throttledRequests; return _throttledRequests;
} }
/* ------------------------------------------------------------ */
/** /**
* Set number of requests over the rate limit able to be * Set number of requests over the rate limit able to be
* considered at once. * considered at once.
@ -745,11 +752,11 @@ public class DoSFilter implements Filter
*/ */
public void setThrottledRequests(int value) public void setThrottledRequests(int value)
{ {
_passes = new Semaphore((value-_throttledRequests+_passes.availablePermits()), true); int permits = _passes == null ? 0 : _passes.availablePermits();
_passes = new Semaphore((value - _throttledRequests + permits), true);
_throttledRequests = value; _throttledRequests = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get amount of time (in milliseconds) to async wait for semaphore. * Get amount of time (in milliseconds) to async wait for semaphore.
* *
@ -760,7 +767,6 @@ public class DoSFilter implements Filter
return _throttleMs; return _throttleMs;
} }
/* ------------------------------------------------------------ */
/** /**
* Set amount of time (in milliseconds) to async wait for semaphore. * Set amount of time (in milliseconds) to async wait for semaphore.
* *
@ -771,7 +777,6 @@ public class DoSFilter implements Filter
_throttleMs = value; _throttleMs = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get maximum amount of time (in milliseconds) to allow * Get maximum amount of time (in milliseconds) to allow
* the request to process. * the request to process.
@ -783,7 +788,6 @@ public class DoSFilter implements Filter
return _maxRequestMs; return _maxRequestMs;
} }
/* ------------------------------------------------------------ */
/** /**
* Set maximum amount of time (in milliseconds) to allow * Set maximum amount of time (in milliseconds) to allow
* the request to process. * the request to process.
@ -795,7 +799,6 @@ public class DoSFilter implements Filter
_maxRequestMs = value; _maxRequestMs = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get maximum amount of time (in milliseconds) to keep track * Get maximum amount of time (in milliseconds) to keep track
* of request rates for a connection, before deciding that * of request rates for a connection, before deciding that
@ -808,7 +811,6 @@ public class DoSFilter implements Filter
return _maxIdleTrackerMs; return _maxIdleTrackerMs;
} }
/* ------------------------------------------------------------ */
/** /**
* Set maximum amount of time (in milliseconds) to keep track * Set maximum amount of time (in milliseconds) to keep track
* of request rates for a connection, before deciding that * of request rates for a connection, before deciding that
@ -821,7 +823,6 @@ public class DoSFilter implements Filter
_maxIdleTrackerMs = value; _maxIdleTrackerMs = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Check flag to insert the DoSFilter headers into the response. * Check flag to insert the DoSFilter headers into the response.
* *
@ -832,7 +833,6 @@ public class DoSFilter implements Filter
return _insertHeaders; return _insertHeaders;
} }
/* ------------------------------------------------------------ */
/** /**
* Set flag to insert the DoSFilter headers into the response. * Set flag to insert the DoSFilter headers into the response.
* *
@ -843,7 +843,6 @@ public class DoSFilter implements Filter
_insertHeaders = value; _insertHeaders = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get flag to have usage rate tracked by session if a session exists. * Get flag to have usage rate tracked by session if a session exists.
* *
@ -854,9 +853,9 @@ public class DoSFilter implements Filter
return _trackSessions; return _trackSessions;
} }
/* ------------------------------------------------------------ */
/** /**
* Set flag to have usage rate tracked by session if a session exists. * Set flag to have usage rate tracked by session if a session exists.
*
* @param value value of the flag * @param value value of the flag
*/ */
public void setTrackSessions(boolean value) public void setTrackSessions(boolean value)
@ -864,7 +863,6 @@ public class DoSFilter implements Filter
_trackSessions = value; _trackSessions = value;
} }
/* ------------------------------------------------------------ */
/** /**
* Get flag to have usage rate tracked by IP+port (effectively connection) * Get flag to have usage rate tracked by IP+port (effectively connection)
* if session tracking is not used. * if session tracking is not used.
@ -876,8 +874,6 @@ public class DoSFilter implements Filter
return _remotePort; return _remotePort;
} }
/* ------------------------------------------------------------ */
/** /**
* Set flag to have usage rate tracked by IP+port (effectively connection) * Set flag to have usage rate tracked by IP+port (effectively connection)
* if session tracking is not used. * if session tracking is not used.
@ -889,7 +885,22 @@ public class DoSFilter implements Filter
_remotePort = value; _remotePort = value;
} }
/* ------------------------------------------------------------ */ /**
* @return whether this filter is enabled
*/
public boolean isEnabled()
{
return _enabled;
}
/**
* @param enabled whether this filter is enabled
*/
public void setEnabled(boolean enabled)
{
_enabled = enabled;
}
/** /**
* Get a list of IP addresses that will not be rate limited. * Get a list of IP addresses that will not be rate limited.
* *
@ -897,11 +908,17 @@ public class DoSFilter implements Filter
*/ */
public String getWhitelist() public String getWhitelist()
{ {
return _whitelistStr; StringBuilder result = new StringBuilder();
for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
{
String address = iterator.next();
result.append(address);
if (iterator.hasNext())
result.append(",");
}
return result.toString();
} }
/* ------------------------------------------------------------ */
/** /**
* Set a list of IP addresses that will not be rate limited. * Set a list of IP addresses that will not be rate limited.
* *
@ -909,8 +926,40 @@ public class DoSFilter implements Filter
*/ */
public void setWhitelist(String value) public void setWhitelist(String value)
{ {
_whitelistStr = value; List<String> result = new ArrayList<String>();
initWhitelist(); for (String address : value.split(","))
addWhitelistAddress(result, address);
_whitelist.clear();
_whitelist.addAll(result);
LOG.debug("Whitelisted IP addresses: {}", result);
}
public void clearWhitelist()
{
_whitelist.clear();
}
public boolean addWhitelistAddress(String address)
{
return addWhitelistAddress(_whitelist, address);
}
private boolean addWhitelistAddress(List<String> list, String address)
{
address = address.trim();
if (address.length() > 0)
{
if (CIDR_PATTERN.matcher(address).matches() || IP_PATTERN.matcher(address).matches())
return list.add(address);
else
LOG.warn("Ignoring malformed whitelist IP address {}", address);
}
return false;
}
public boolean removeWhitelistAddress(String address)
{
return _whitelist.remove(address);
} }
/** /**
@ -924,7 +973,6 @@ public class DoSFilter implements Filter
transient protected final long[] _timestamps; transient protected final long[] _timestamps;
transient protected int _next; transient protected int _next;
public RateTracker(String id, int type, int maxRequestsPerSecond) public RateTracker(String id, int type, int maxRequestsPerSecond)
{ {
_id = id; _id = id;
@ -946,11 +994,9 @@ public class DoSFilter implements Filter
_next = (_next + 1) % _timestamps.length; _next = (_next + 1) % _timestamps.length;
} }
boolean exceeded=last!=0 && (now-last)<1000L; return last != 0 && (now - last) < 1000L;
return exceeded;
} }
public String getId() public String getId()
{ {
return _id; return _id;
@ -961,29 +1007,27 @@ public class DoSFilter implements Filter
return _type; return _type;
} }
public void valueBound(HttpSessionBindingEvent event) public void valueBound(HttpSessionBindingEvent event)
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Value bound:"+_id); LOG.debug("Value bound: {}", getId());
} }
public void valueUnbound(HttpSessionBindingEvent event) public void valueUnbound(HttpSessionBindingEvent event)
{ {
//take the tracker out of the list of trackers //take the tracker out of the list of trackers
if (_rateTrackers != null)
_rateTrackers.remove(_id); _rateTrackers.remove(_id);
if (LOG.isDebugEnabled()) LOG.debug("Tracker removed: "+_id); if (LOG.isDebugEnabled())
LOG.debug("Tracker removed: {}", getId());
} }
public void sessionWillPassivate(HttpSessionEvent se) public void sessionWillPassivate(HttpSessionEvent se)
{ {
//take the tracker of the list of trackers (if its still there) //take the tracker of the list of trackers (if its still there)
//and ensure that we take ourselves out of the session so we are not saved //and ensure that we take ourselves out of the session so we are not saved
if (_rateTrackers != null)
_rateTrackers.remove(_id); _rateTrackers.remove(_id);
se.getSession().removeAttribute(__TRACKER); se.getSession().removeAttribute(__TRACKER);
if (LOG.isDebugEnabled()) LOG.debug("Value removed: "+_id); if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
} }
public void sessionDidActivate(HttpSessionEvent se) public void sessionDidActivate(HttpSessionEvent se)
@ -991,10 +1035,7 @@ public class DoSFilter implements Filter
LOG.warn("Unexpected session activation"); LOG.warn("Unexpected session activation");
} }
public void expired() public void expired()
{
if (_rateTrackers != null && _trackerTimeoutQ != null)
{ {
long now = _trackerTimeoutQ.getNow(); long now = _trackerTimeoutQ.getNow();
int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1); int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
@ -1006,15 +1047,12 @@ public class DoSFilter implements Filter
else else
_rateTrackers.remove(_id); _rateTrackers.remove(_id);
} }
}
@Override @Override
public String toString() public String toString()
{ {
return "RateTracker/" + _id + "/" + _type; return "RateTracker/" + _id + "/" + _type;
} }
} }
class FixedRateTracker extends RateTracker class FixedRateTracker extends RateTracker

View File

@ -9,4 +9,10 @@ maxIdleTrackerMs: maximum amount of time (in milliseconds) to keep track of requ
insertHeaders: insert the DoSFilter headers into the response. insertHeaders: insert the DoSFilter headers into the response.
trackSessions: usage rate is tracked by session if a session exists. trackSessions: usage rate is tracked by session if a session exists.
remotePort: usage rate is tracked by IP+port (effectively connection) if session tracking is not used. remotePort: usage rate is tracked by IP+port (effectively connection) if session tracking is not used.
ipWhitelist: list of IP addresses that will not be rate limited. enabled: whether this filter is enabled
whitelist: comma separated list of IP addresses that will not be rate limited.
clearWhitelist(): clears the list of IP addresses that will not be rate limited.
addWhitelistAddress(java.lang.String):ACTION: adds an IP address that will not be rate limited.
addWhitelistAddress(java.lang.String)[0]:address: the IP address that will not be rate limited.
removeWhitelistAddress(java.lang.String):ACTION: removes an IP address that will not be rate limited.
removeWhitelistAddress(java.lang.String)[0]:address: the IP address that will not be rate limited.

View File

@ -0,0 +1,88 @@
//
// ========================================================================
// Copyright (c) 1995-2013 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.servlets;
import java.lang.management.ManagementFactory;
import java.util.EnumSet;
import java.util.Set;
import javax.management.Attribute;
import javax.management.MBeanServer;
import javax.management.ObjectName;
import org.eclipse.jetty.jmx.MBeanContainer;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.DispatcherType;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.nio.SelectChannelConnector;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.junit.Assert;
import org.junit.Test;
public class DoSFilterJMXTest
{
@Test
public void testDoSFilterJMX() throws Exception
{
Server server = new Server();
Connector connector = new SelectChannelConnector();
connector.setPort(0);
server.addConnector(connector);
ServletContextHandler context = new ServletContextHandler(server, "/", ServletContextHandler.SESSIONS);
DoSFilter filter = new DoSFilter();
FilterHolder holder = new FilterHolder(filter);
String name = "dos";
holder.setName(name);
holder.setInitParameter(DoSFilter.MANAGED_ATTR_INIT_PARAM, "true");
context.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST));
context.setInitParameter(ServletContextHandler.MANAGED_ATTRIBUTES, name);
MBeanServer mbeanServer = ManagementFactory.getPlatformMBeanServer();
MBeanContainer mbeanContainer = new MBeanContainer(mbeanServer);
server.addBean(mbeanContainer);
server.getContainer().addEventListener(mbeanContainer);
server.start();
String domain = DoSFilter.class.getPackage().getName();
Set<ObjectName> mbeanNames = mbeanServer.queryNames(ObjectName.getInstance(domain + ":*"), null);
Assert.assertEquals(1, mbeanNames.size());
ObjectName objectName = mbeanNames.iterator().next();
boolean value = (Boolean)mbeanServer.getAttribute(objectName, "enabled");
mbeanServer.setAttribute(objectName, new Attribute("enabled", !value));
Assert.assertEquals(!value, filter.isEnabled());
String whitelist = (String)mbeanServer.getAttribute(objectName, "whitelist");
String address = "127.0.0.1";
Assert.assertFalse(whitelist.contains(address));
boolean result = (Boolean)mbeanServer.invoke(objectName, "addWhitelistAddress", new Object[]{address}, new String[]{String.class.getName()});
Assert.assertTrue(result);
whitelist = (String)mbeanServer.getAttribute(objectName, "whitelist");
Assert.assertTrue(whitelist.contains(address));
result = (Boolean)mbeanServer.invoke(objectName, "removeWhitelistAddress", new Object[]{address}, new String[]{String.class.getName()});
Assert.assertTrue(result);
whitelist = (String)mbeanServer.getAttribute(objectName, "whitelist");
Assert.assertFalse(whitelist.contains(address));
server.stop();
}
}

View File

@ -18,18 +18,21 @@
package org.eclipse.jetty.servlets; package org.eclipse.jetty.servlets;
import static org.junit.Assert.assertFalse; import java.util.ArrayList;
import static org.junit.Assert.assertTrue; import java.util.List;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.servlets.DoSFilter.RateTracker; import org.eclipse.jetty.servlets.DoSFilter.RateTracker;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
import org.junit.Assert;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class DoSFilterTest extends AbstractDoSFilterTest public class DoSFilterTest extends AbstractDoSFilterTest
{ {
private static final Logger LOG = Log.getLogger(DoSFilterTest.class); private static final Logger LOG = Log.getLogger(DoSFilterTest.class);
@ -69,6 +72,21 @@ public class DoSFilterTest extends AbstractDoSFilterTest
assertFalse("Should not exceed as we sleep 300s for each hit and thus do less than 4 hits/s",exceeded); assertFalse("Should not exceed as we sleep 300s for each hit and thus do less than 4 hits/s",exceeded);
} }
@Test
public void testWhitelist() throws Exception
{
DoSFilter filter = new DoSFilter();
List<String> whitelist = new ArrayList<String>();
whitelist.add("192.168.0.1");
whitelist.add("10.0.0.0/8");
Assert.assertTrue(filter.checkWhitelist(whitelist, "192.168.0.1"));
Assert.assertFalse(filter.checkWhitelist(whitelist, "192.168.0.2"));
Assert.assertFalse(filter.checkWhitelist(whitelist, "11.12.13.14"));
Assert.assertTrue(filter.checkWhitelist(whitelist, "10.11.12.13"));
Assert.assertTrue(filter.checkWhitelist(whitelist, "10.0.0.0"));
Assert.assertFalse(filter.checkWhitelist(whitelist, "0.0.0.0"));
}
private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws InterruptedException private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws InterruptedException
{ {
boolean exceeded = false; boolean exceeded = false;