Use only IP tracking for the DosFilter to fix #1256 Signed-off-by: gregw <gregw@webtide.com>
This commit is contained in:
parent
8c94490e18
commit
168d8715d4
|
@ -32,9 +32,7 @@ The filter works on the assumption that the attacker might be written in simple
|
||||||
[[dos-filter-using]]
|
[[dos-filter-using]]
|
||||||
==== Using the DoS Filter
|
==== Using the DoS Filter
|
||||||
|
|
||||||
Jetty places throttled requests in a priority queue, giving priority first to authenticated users and users with an HttpSession, then to connections identified by their IP addresses.
|
Jetty places throttled requests in a queue, and proceed only when there is capacity available.
|
||||||
Connections with no way to identify them have lowest priority.
|
|
||||||
To uniquely identify authenticated users, you should implement the The extractUserId(ServletRequest request) function.
|
|
||||||
|
|
||||||
===== Required JARs
|
===== Required JARs
|
||||||
|
|
||||||
|
@ -94,11 +92,8 @@ Default is 30000L.
|
||||||
insertHeaders::
|
insertHeaders::
|
||||||
If true, insert the DoSFilter headers into the response.
|
If true, insert the DoSFilter headers into the response.
|
||||||
Defaults to true.
|
Defaults to true.
|
||||||
trackSessions::
|
|
||||||
If true, usage rate is tracked by session if a session exists.
|
|
||||||
Defaults to true.
|
|
||||||
remotePort::
|
remotePort::
|
||||||
If true and session tracking is not used, then rate is tracked by IP and port (effectively connection).
|
If true, then rate is tracked by IP and port (effectively connection).
|
||||||
Defaults to false.
|
Defaults to false.
|
||||||
ipWhitelist::
|
ipWhitelist::
|
||||||
A comma-separated list of IP addresses that will not be rate limited.
|
A comma-separated list of IP addresses that will not be rate limited.
|
||||||
|
|
|
@ -17,10 +17,8 @@ import java.io.IOException;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Queue;
|
import java.util.Queue;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
@ -43,11 +41,6 @@ import javax.servlet.ServletRequest;
|
||||||
import javax.servlet.ServletResponse;
|
import javax.servlet.ServletResponse;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import javax.servlet.http.HttpSession;
|
|
||||||
import javax.servlet.http.HttpSessionActivationListener;
|
|
||||||
import javax.servlet.http.HttpSessionBindingEvent;
|
|
||||||
import javax.servlet.http.HttpSessionBindingListener;
|
|
||||||
import javax.servlet.http.HttpSessionEvent;
|
|
||||||
|
|
||||||
import org.eclipse.jetty.http.HttpStatus;
|
import org.eclipse.jetty.http.HttpStatus;
|
||||||
import org.eclipse.jetty.server.handler.ContextHandler;
|
import org.eclipse.jetty.server.handler.ContextHandler;
|
||||||
|
@ -74,10 +67,8 @@ import org.slf4j.LoggerFactory;
|
||||||
* second. If a limit is exceeded, the request is either rejected, delayed, or
|
* second. If a limit is exceeded, the request is either rejected, delayed, or
|
||||||
* throttled.
|
* throttled.
|
||||||
* <p>
|
* <p>
|
||||||
* When a request is throttled, it is placed in a priority queue. Priority is
|
* When a request is throttled, it is placed in a queue and will only proceed
|
||||||
* given first to authenticated users and users with an HttpSession, then
|
* when there is capacity.
|
||||||
* connections which can be identified by their IP addresses. Connections with
|
|
||||||
* no way to identify them are given lowest priority.
|
|
||||||
* <p>
|
* <p>
|
||||||
* The {@link #extractUserId(ServletRequest request)} function should be
|
* The {@link #extractUserId(ServletRequest request)} function should be
|
||||||
* implemented, in order to uniquely identify authenticated users.
|
* implemented, in order to uniquely identify authenticated users.
|
||||||
|
@ -106,10 +97,8 @@ import org.slf4j.LoggerFactory;
|
||||||
* before deciding that the user has gone away, and discarding it</dd>
|
* before deciding that the user has gone away, and discarding it</dd>
|
||||||
* <dt>insertHeaders</dt>
|
* <dt>insertHeaders</dt>
|
||||||
* <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd>
|
* <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd>
|
||||||
* <dt>trackSessions</dt>
|
|
||||||
* <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
|
|
||||||
* <dt>remotePort</dt>
|
* <dt>remotePort</dt>
|
||||||
* <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
|
* <dd>if true then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
|
||||||
* <dt>ipWhitelist</dt>
|
* <dt>ipWhitelist</dt>
|
||||||
* <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
|
* <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
|
||||||
* <dt>managedAttr</dt>
|
* <dt>managedAttr</dt>
|
||||||
|
@ -156,12 +145,14 @@ public class DoSFilter implements Filter
|
||||||
static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
|
static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
|
||||||
static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
|
static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
|
||||||
static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
|
static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
|
||||||
|
@Deprecated
|
||||||
static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
|
static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
|
||||||
static final String REMOTE_PORT_INIT_PARAM = "remotePort";
|
static final String REMOTE_PORT_INIT_PARAM = "remotePort";
|
||||||
static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
|
static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
|
||||||
static final String ENABLED_INIT_PARAM = "enabled";
|
static final String ENABLED_INIT_PARAM = "enabled";
|
||||||
static final String TOO_MANY_CODE = "tooManyCode";
|
static final String TOO_MANY_CODE = "tooManyCode";
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public enum RateType
|
public enum RateType
|
||||||
{
|
{
|
||||||
AUTH,
|
AUTH,
|
||||||
|
@ -181,7 +172,6 @@ public class DoSFilter implements Filter
|
||||||
private volatile long _maxRequestMs;
|
private volatile long _maxRequestMs;
|
||||||
private volatile long _maxIdleTrackerMs;
|
private volatile long _maxIdleTrackerMs;
|
||||||
private volatile boolean _insertHeaders;
|
private volatile boolean _insertHeaders;
|
||||||
private volatile boolean _trackSessions;
|
|
||||||
private volatile boolean _remotePort;
|
private volatile boolean _remotePort;
|
||||||
private volatile boolean _enabled;
|
private volatile boolean _enabled;
|
||||||
private volatile String _name;
|
private volatile String _name;
|
||||||
|
@ -189,20 +179,14 @@ public class DoSFilter implements Filter
|
||||||
private Semaphore _passes;
|
private Semaphore _passes;
|
||||||
private volatile int _throttledRequests;
|
private volatile int _throttledRequests;
|
||||||
private volatile int _maxRequestsPerSec;
|
private volatile int _maxRequestsPerSec;
|
||||||
private Map<RateType, Queue<AsyncContext>> _queues = new HashMap<>();
|
private final Queue<AsyncContext> _queue = new ConcurrentLinkedQueue<>();
|
||||||
private Map<RateType, AsyncListener> _listeners = new HashMap<>();
|
private final AsyncListener _asyncListener = new DoSAsyncListener();
|
||||||
private Scheduler _scheduler;
|
private Scheduler _scheduler;
|
||||||
private ServletContext _context;
|
private ServletContext _context;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void init(FilterConfig filterConfig) throws ServletException
|
public void init(FilterConfig filterConfig) throws ServletException
|
||||||
{
|
{
|
||||||
for (RateType rateType : RateType.values())
|
|
||||||
{
|
|
||||||
_queues.put(rateType, new ConcurrentLinkedQueue<>());
|
|
||||||
_listeners.put(rateType, new DoSAsyncListener(rateType));
|
|
||||||
}
|
|
||||||
|
|
||||||
_rateTrackers.clear();
|
_rateTrackers.clear();
|
||||||
|
|
||||||
int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
|
int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
|
||||||
|
@ -395,15 +379,14 @@ public class DoSFilter implements Filter
|
||||||
long throttleMs = getThrottleMs();
|
long throttleMs = getThrottleMs();
|
||||||
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
|
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
|
||||||
{
|
{
|
||||||
RateType priority = getPriority(request, tracker);
|
|
||||||
request.setAttribute(__THROTTLED, Boolean.TRUE);
|
request.setAttribute(__THROTTLED, Boolean.TRUE);
|
||||||
if (isInsertHeaders())
|
if (isInsertHeaders())
|
||||||
response.addHeader("DoSFilter", "throttled");
|
response.addHeader("DoSFilter", "throttled");
|
||||||
AsyncContext asyncContext = request.startAsync();
|
AsyncContext asyncContext = request.startAsync();
|
||||||
request.setAttribute(_suspended, Boolean.TRUE);
|
request.setAttribute(_suspended, Boolean.TRUE);
|
||||||
asyncContext.setTimeout(throttleMs);
|
asyncContext.setTimeout(throttleMs);
|
||||||
asyncContext.addListener(_listeners.get(priority));
|
asyncContext.addListener(_asyncListener);
|
||||||
_queues.get(priority).add(asyncContext);
|
_queue.add(asyncContext);
|
||||||
if (LOG.isDebugEnabled())
|
if (LOG.isDebugEnabled())
|
||||||
LOG.debug("Throttled {}, {}ms", request, throttleMs);
|
LOG.debug("Throttled {}, {}ms", request, throttleMs);
|
||||||
return;
|
return;
|
||||||
|
@ -447,10 +430,7 @@ public class DoSFilter implements Filter
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// Wake up the next highest priority request.
|
AsyncContext asyncContext = _queue.poll();
|
||||||
for (RateType rateType : RateType.values())
|
|
||||||
{
|
|
||||||
AsyncContext asyncContext = _queues.get(rateType).poll();
|
|
||||||
if (asyncContext != null)
|
if (asyncContext != null)
|
||||||
{
|
{
|
||||||
ServletRequest candidate = asyncContext.getRequest();
|
ServletRequest candidate = asyncContext.getRequest();
|
||||||
|
@ -461,8 +441,6 @@ public class DoSFilter implements Filter
|
||||||
LOG.debug("Resuming {}", request);
|
LOG.debug("Resuming {}", request);
|
||||||
candidate.setAttribute(_resumed, Boolean.TRUE);
|
candidate.setAttribute(_resumed, Boolean.TRUE);
|
||||||
asyncContext.dispatch();
|
asyncContext.dispatch();
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -524,27 +502,13 @@ public class DoSFilter implements Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get priority for this request, based on user type
|
* @return null
|
||||||
*
|
* @deprecated Priority no longer supported
|
||||||
* @param request the current request
|
|
||||||
* @param tracker the rate tracker for this request
|
|
||||||
* @return the priority for this request
|
|
||||||
*/
|
|
||||||
private RateType getPriority(HttpServletRequest request, RateTracker tracker)
|
|
||||||
{
|
|
||||||
if (extractUserId(request) != null)
|
|
||||||
return RateType.AUTH;
|
|
||||||
if (tracker != null)
|
|
||||||
return tracker.getType();
|
|
||||||
return RateType.UNKNOWN;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return the maximum priority that we can assign to a request
|
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
protected RateType getMaxPriority()
|
protected RateType getMaxPriority()
|
||||||
{
|
{
|
||||||
return RateType.AUTH;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setListener(DoSFilter.Listener listener)
|
public void setListener(DoSFilter.Listener listener)
|
||||||
|
@ -570,62 +534,30 @@ public class DoSFilter implements Filter
|
||||||
* <p>
|
* <p>
|
||||||
* Assumes that each connection has an identifying characteristic, and goes
|
* Assumes that each connection has an identifying characteristic, and goes
|
||||||
* through them in order, taking the first that matches: user id (logged
|
* through them in order, taking the first that matches: user id (logged
|
||||||
* in), session id, client IP address. Unidentifiable connections are lumped
|
* in), client IP address. Unidentifiable connections are lumped
|
||||||
* into one.
|
* into one.
|
||||||
* <p>
|
|
||||||
* When a session expires, its rate tracker is automatically deleted.
|
|
||||||
*
|
*
|
||||||
* @param request the current request
|
* @param request the current request
|
||||||
* @return the request rate tracker for the current connection
|
* @return the request rate tracker for the current connection
|
||||||
*/
|
*/
|
||||||
RateTracker getRateTracker(ServletRequest request)
|
RateTracker getRateTracker(ServletRequest request)
|
||||||
{
|
{
|
||||||
HttpSession session = ((HttpServletRequest)request).getSession(false);
|
String loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
|
||||||
|
|
||||||
String loadId = extractUserId(request);
|
|
||||||
final RateType type;
|
|
||||||
if (loadId != null)
|
|
||||||
{
|
|
||||||
type = RateType.AUTH;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (isTrackSessions() && session != null && !session.isNew())
|
|
||||||
{
|
|
||||||
loadId = session.getId();
|
|
||||||
type = RateType.SESSION;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
|
|
||||||
type = RateType.IP;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RateTracker tracker = _rateTrackers.get(loadId);
|
RateTracker tracker = _rateTrackers.get(loadId);
|
||||||
|
|
||||||
if (tracker == null)
|
if (tracker == null)
|
||||||
{
|
{
|
||||||
boolean allowed = checkWhitelist(request.getRemoteAddr());
|
boolean allowed = checkWhitelist(request.getRemoteAddr());
|
||||||
int maxRequestsPerSec = getMaxRequestsPerSec();
|
int maxRequestsPerSec = getMaxRequestsPerSec();
|
||||||
tracker = allowed ? new FixedRateTracker(_context, _name, loadId, type, maxRequestsPerSec)
|
tracker = allowed ? new FixedRateTracker(_context, _name, loadId, maxRequestsPerSec)
|
||||||
: new RateTracker(_context, _name, loadId, type, maxRequestsPerSec);
|
: new RateTracker(_context, _name, loadId, maxRequestsPerSec);
|
||||||
tracker.setContext(_context);
|
tracker.setContext(_context);
|
||||||
RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
|
RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
|
||||||
if (existing != null)
|
if (existing != null)
|
||||||
tracker = existing;
|
tracker = existing;
|
||||||
|
|
||||||
if (type == RateType.IP)
|
|
||||||
{
|
|
||||||
// USER_IP expiration from _rateTrackers is handled by the _scheduler
|
|
||||||
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
|
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
|
||||||
}
|
}
|
||||||
else if (session != null)
|
|
||||||
{
|
|
||||||
// USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
|
|
||||||
session.setAttribute(__TRACKER, tracker);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tracker;
|
return tracker;
|
||||||
}
|
}
|
||||||
|
@ -750,7 +682,7 @@ public class DoSFilter implements Filter
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
// Sets the _prefix_ most significant bits to 1
|
// Sets the _prefix_ most significant bits to 1
|
||||||
result[index] = (byte)~((1 << (8 - prefix)) - 1);
|
result[index] = (byte)-(1 << (8 - prefix));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -776,12 +708,11 @@ public class DoSFilter implements Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the user id, used to track this connection.
|
* @param request ignored
|
||||||
* This SHOULD be overridden by subclasses.
|
* @return null
|
||||||
*
|
* @deprecated User ID no longer supported
|
||||||
* @param request the current request
|
|
||||||
* @return a unique user id, if logged in; otherwise null.
|
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
protected String extractUserId(ServletRequest request)
|
protected String extractUserId(ServletRequest request)
|
||||||
{
|
{
|
||||||
return null;
|
return null;
|
||||||
|
@ -996,26 +927,29 @@ public class DoSFilter implements Filter
|
||||||
* Get flag to have usage rate tracked by session if a session exists.
|
* Get flag to have usage rate tracked by session if a session exists.
|
||||||
*
|
*
|
||||||
* @return value of the flag
|
* @return value of the flag
|
||||||
|
* @deprecated Session tracking is no longer supported
|
||||||
*/
|
*/
|
||||||
@ManagedAttribute("usage rate is tracked by session if one exists")
|
@Deprecated
|
||||||
public boolean isTrackSessions()
|
public boolean isTrackSessions()
|
||||||
{
|
{
|
||||||
return _trackSessions;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set flag to have usage rate tracked by session if a session exists.
|
* Set flag to have usage rate tracked by session if a session exists.
|
||||||
*
|
*
|
||||||
* @param value value of the flag
|
* @param value value of the flag
|
||||||
|
* @deprecated Session tracking is no longer supported
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public void setTrackSessions(boolean value)
|
public void setTrackSessions(boolean value)
|
||||||
{
|
{
|
||||||
_trackSessions = value;
|
if (value)
|
||||||
|
LOG.warn("Session Tracking is not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get flag to have usage rate tracked by IP+port (effectively connection)
|
* Get flag to have usage rate tracked by IP+port (effectively connection)
|
||||||
* if session tracking is not used.
|
|
||||||
*
|
*
|
||||||
* @return value of the flag
|
* @return value of the flag
|
||||||
*/
|
*/
|
||||||
|
@ -1027,7 +961,6 @@ public class DoSFilter implements Filter
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set flag to have usage rate tracked by IP+port (effectively connection)
|
* Set flag to have usage rate tracked by IP+port (effectively connection)
|
||||||
* if session tracking is not used.
|
|
||||||
*
|
*
|
||||||
* @param value value of the flag
|
* @param value value of the flag
|
||||||
*/
|
*/
|
||||||
|
@ -1130,7 +1063,10 @@ public class DoSFilter implements Filter
|
||||||
private boolean addWhitelistAddress(List<String> list, String address)
|
private boolean addWhitelistAddress(List<String> list, String address)
|
||||||
{
|
{
|
||||||
address = address.trim();
|
address = address.trim();
|
||||||
return address.length() > 0 && list.add(address);
|
if (address.length() <= 0)
|
||||||
|
return false;
|
||||||
|
list.add(address);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1157,7 +1093,7 @@ public class DoSFilter implements Filter
|
||||||
* A RateTracker is associated with a connection, and stores request rate
|
* A RateTracker is associated with a connection, and stores request rate
|
||||||
* data.
|
* data.
|
||||||
*/
|
*/
|
||||||
static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
|
static class RateTracker implements Runnable, Serializable
|
||||||
{
|
{
|
||||||
private static final long serialVersionUID = 3534663738034577872L;
|
private static final long serialVersionUID = 3534663738034577872L;
|
||||||
|
|
||||||
|
@ -1165,18 +1101,16 @@ public class DoSFilter implements Filter
|
||||||
protected final String _filterName;
|
protected final String _filterName;
|
||||||
protected transient ServletContext _context;
|
protected transient ServletContext _context;
|
||||||
protected final String _id;
|
protected final String _id;
|
||||||
protected final RateType _type;
|
|
||||||
protected final int _maxRequestsPerSecond;
|
protected final int _maxRequestsPerSecond;
|
||||||
protected final long[] _timestamps;
|
protected final long[] _timestamps;
|
||||||
|
|
||||||
protected int _next;
|
protected int _next;
|
||||||
|
|
||||||
public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond)
|
RateTracker(ServletContext context, String filterName, String id, int maxRequestsPerSecond)
|
||||||
{
|
{
|
||||||
_context = context;
|
_context = context;
|
||||||
_filterName = filterName;
|
_filterName = filterName;
|
||||||
_id = id;
|
_id = id;
|
||||||
_type = type;
|
|
||||||
_maxRequestsPerSecond = maxRequestsPerSecond;
|
_maxRequestsPerSecond = maxRequestsPerSecond;
|
||||||
_timestamps = new long[maxRequestsPerSecond];
|
_timestamps = new long[maxRequestsPerSecond];
|
||||||
_next = 0;
|
_next = 0;
|
||||||
|
@ -1212,52 +1146,6 @@ public class DoSFilter implements Filter
|
||||||
return _id;
|
return _id;
|
||||||
}
|
}
|
||||||
|
|
||||||
public RateType getType()
|
|
||||||
{
|
|
||||||
return _type;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void valueBound(HttpSessionBindingEvent event)
|
|
||||||
{
|
|
||||||
if (LOG.isDebugEnabled())
|
|
||||||
LOG.debug("Value bound: {}", getId());
|
|
||||||
_context = event.getSession().getServletContext();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void valueUnbound(HttpSessionBindingEvent event)
|
|
||||||
{
|
|
||||||
//take the tracker out of the list of trackers
|
|
||||||
DoSFilter filter = (DoSFilter)event.getSession().getServletContext().getAttribute(_filterName);
|
|
||||||
removeFromRateTrackers(filter, _id);
|
|
||||||
_context = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sessionWillPassivate(HttpSessionEvent se)
|
|
||||||
{
|
|
||||||
//take the tracker of the list of trackers (if its still there)
|
|
||||||
DoSFilter filter = (DoSFilter)se.getSession().getServletContext().getAttribute(_filterName);
|
|
||||||
removeFromRateTrackers(filter, _id);
|
|
||||||
_context = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void sessionDidActivate(HttpSessionEvent se)
|
|
||||||
{
|
|
||||||
RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
|
|
||||||
ServletContext context = se.getSession().getServletContext();
|
|
||||||
tracker.setContext(context);
|
|
||||||
DoSFilter filter = (DoSFilter)context.getAttribute(_filterName);
|
|
||||||
if (filter == null)
|
|
||||||
{
|
|
||||||
LOG.info("No filter {} for rate tracker {}", _filterName, tracker);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
addToRateTrackers(filter, tracker);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setContext(ServletContext context)
|
public void setContext(ServletContext context)
|
||||||
{
|
{
|
||||||
_context = context;
|
_context = context;
|
||||||
|
@ -1309,7 +1197,7 @@ public class DoSFilter implements Filter
|
||||||
@Override
|
@Override
|
||||||
public String toString()
|
public String toString()
|
||||||
{
|
{
|
||||||
return "RateTracker/" + _id + "/" + _type;
|
return "RateTracker/" + _id;
|
||||||
}
|
}
|
||||||
|
|
||||||
public class Overage implements OverLimit
|
public class Overage implements OverLimit
|
||||||
|
@ -1323,12 +1211,6 @@ public class DoSFilter implements Filter
|
||||||
this.count = count;
|
this.count = count;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public RateType getRateType()
|
|
||||||
{
|
|
||||||
return _type;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getRateId()
|
public String getRateId()
|
||||||
{
|
{
|
||||||
|
@ -1350,23 +1232,20 @@ public class DoSFilter implements Filter
|
||||||
@Override
|
@Override
|
||||||
public String toString()
|
public String toString()
|
||||||
{
|
{
|
||||||
final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName());
|
return OverLimit.class.getSimpleName() + '@' + Integer.toHexString(hashCode()) +
|
||||||
sb.append('@').append(Integer.toHexString(hashCode()));
|
"[id=" + getRateId() +
|
||||||
sb.append("[type=").append(getRateType());
|
", duration=" + duration +
|
||||||
sb.append(", id=").append(getRateId());
|
", count=" + count +
|
||||||
sb.append(", duration=").append(duration);
|
']';
|
||||||
sb.append(", count=").append(count);
|
|
||||||
sb.append(']');
|
|
||||||
return sb.toString();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class FixedRateTracker extends RateTracker
|
private static class FixedRateTracker extends RateTracker
|
||||||
{
|
{
|
||||||
public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked)
|
public FixedRateTracker(ServletContext context, String filterName, String id, int numRecentRequestsTracked)
|
||||||
{
|
{
|
||||||
super(context, filterName, id, type, numRecentRequestsTracked);
|
super(context, filterName, id, numRecentRequestsTracked);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -1417,17 +1296,14 @@ public class DoSFilter implements Filter
|
||||||
|
|
||||||
private class DoSAsyncListener extends DoSTimeoutAsyncListener
|
private class DoSAsyncListener extends DoSTimeoutAsyncListener
|
||||||
{
|
{
|
||||||
private final RateType priority;
|
public DoSAsyncListener()
|
||||||
|
|
||||||
public DoSAsyncListener(RateType priority)
|
|
||||||
{
|
{
|
||||||
this.priority = priority;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onTimeout(AsyncEvent event) throws IOException
|
public void onTimeout(AsyncEvent event) throws IOException
|
||||||
{
|
{
|
||||||
_queues.get(priority).remove(event.getAsyncContext());
|
_queue.remove(event.getAsyncContext()); // TODO what???
|
||||||
super.onTimeout(event);
|
super.onTimeout(event);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1475,8 +1351,6 @@ public class DoSFilter implements Filter
|
||||||
|
|
||||||
public interface OverLimit
|
public interface OverLimit
|
||||||
{
|
{
|
||||||
RateType getRateType();
|
|
||||||
|
|
||||||
String getRateId();
|
String getRateId();
|
||||||
|
|
||||||
Duration getDuration();
|
Duration getDuration();
|
||||||
|
@ -1503,13 +1377,13 @@ public class DoSFilter implements Filter
|
||||||
switch (action)
|
switch (action)
|
||||||
{
|
{
|
||||||
case REJECT:
|
case REJECT:
|
||||||
LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
|
LOG.warn("DoS ALERT: Request rejected ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal());
|
||||||
break;
|
break;
|
||||||
case DELAY:
|
case DELAY:
|
||||||
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
|
LOG.warn("DoS ALERT: Request delayed={}ms, ip={}, overlimit={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getUserPrincipal());
|
||||||
break;
|
break;
|
||||||
case THROTTLE:
|
case THROTTLE:
|
||||||
LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
|
LOG.warn("DoS ALERT: Request throttled ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,9 +45,7 @@ import org.junit.jupiter.api.Test;
|
||||||
import static org.hamcrest.MatcherAssert.assertThat;
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
||||||
import static org.hamcrest.Matchers.is;
|
import static org.hamcrest.Matchers.is;
|
||||||
import static org.hamcrest.Matchers.lessThan;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
public abstract class AbstractDoSFilterTest
|
public abstract class AbstractDoSFilterTest
|
||||||
{
|
{
|
||||||
|
@ -248,60 +246,6 @@ public abstract class AbstractDoSFilterTest
|
||||||
other.join();
|
other.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSessionTracking() throws Exception
|
|
||||||
{
|
|
||||||
// get a session, first
|
|
||||||
String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
|
||||||
String response = doRequests("", 1, 0, 0, requestSession);
|
|
||||||
String sessionId = response.substring(response.indexOf("Set-Cookie: ") + 12, response.indexOf(";"));
|
|
||||||
|
|
||||||
// all other requests use this session
|
|
||||||
String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n";
|
|
||||||
String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n";
|
|
||||||
String responses = doRequests(request + request + request + request + request, 2, 1100, 1100, last);
|
|
||||||
|
|
||||||
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
|
|
||||||
assertEquals(2, count(responses, "DoSFilter: delayed"));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testMultipleSessionTracking() throws Exception
|
|
||||||
{
|
|
||||||
// get some session ids, first
|
|
||||||
String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n";
|
|
||||||
String closeRequest = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
|
|
||||||
String response = doRequests(requestSession + requestSession, 1, 0, 0, closeRequest);
|
|
||||||
|
|
||||||
String[] sessions = response.split("\r\n\r\n");
|
|
||||||
|
|
||||||
String sessionId1 = sessions[0].substring(sessions[0].indexOf("Set-Cookie: ") + 12, sessions[0].indexOf(";"));
|
|
||||||
String sessionId2 = sessions[1].substring(sessions[1].indexOf("Set-Cookie: ") + 12, sessions[1].indexOf(";"));
|
|
||||||
|
|
||||||
// alternate between sessions
|
|
||||||
String request1 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n";
|
|
||||||
String request2 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n";
|
|
||||||
String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n";
|
|
||||||
|
|
||||||
// ensure the sessions are new
|
|
||||||
doRequests(request1 + request2, 1, 1100, 1100, last);
|
|
||||||
Thread.sleep(1000);
|
|
||||||
|
|
||||||
String responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 1100, 1100, last);
|
|
||||||
|
|
||||||
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
|
|
||||||
// This test is system speed dependent, so allow some (20%-ish) requests to be delayed, but not more.
|
|
||||||
assertThat("delayed count", count(responses, "DoSFilter: delayed"), lessThan(2));
|
|
||||||
|
|
||||||
// alternate between sessions
|
|
||||||
responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 250, 250, last);
|
|
||||||
|
|
||||||
// System.err.println(responses);
|
|
||||||
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
|
|
||||||
int delayedRequests = count(responses, "DoSFilter: delayed");
|
|
||||||
assertTrue(delayedRequests >= 2 && delayedRequests <= 5, "delayedRequests: " + delayedRequests + " is not between 2 and 5");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUnresponsiveClient() throws Exception
|
public void testUnresponsiveClient() throws Exception
|
||||||
{
|
{
|
||||||
|
|
|
@ -169,7 +169,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
|
||||||
{
|
{
|
||||||
boolean exceeded = false;
|
boolean exceeded = false;
|
||||||
ServletContext context = new ContextHandler.StaticContext();
|
ServletContext context = new ContextHandler.StaticContext();
|
||||||
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4);
|
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 4);
|
||||||
|
|
||||||
for (int i = 0; i < 5; i++)
|
for (int i = 0; i < 5; i++)
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue