Issue #5185 - Add DoSFilter Listener to allow extensible behavior

+ Currently there's no way to respond to rejected/throttled/delayed
  requests that the DoSFilter impacts.
  A Listener has been added to allow for any behaviors needed
  by a user of the DoSFilter on requests that have been
  impacted by the DoSFilter.
+ Introducing OverLimit and RateType to DoSFilter internals

Signed-off-by: Joakim Erdfelt <joakim.erdfelt@gmail.com>
This commit is contained in:
Joakim Erdfelt 2020-08-24 09:50:50 -05:00
parent 5fef14019a
commit a8ae3f9476
No known key found for this signature in database
GPG Key ID: 2D0E1FB8FE4B68B4
2 changed files with 278 additions and 102 deletions

View File

@ -20,9 +20,13 @@ package org.eclipse.jetty.servlets;
import java.io.IOException;
import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
@ -161,10 +165,13 @@ public class DoSFilter implements Filter
static final String ENABLED_INIT_PARAM = "enabled";
static final String TOO_MANY_CODE = "tooManyCode";
private static final int USER_AUTH = 2;
private static final int USER_SESSION = 2;
private static final int USER_IP = 1;
private static final int USER_UNKNOWN = 0;
public enum RateType
{
AUTH,
SESSION,
IP,
UNKNOWN
}
private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED";
private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED";
@ -181,23 +188,22 @@ public class DoSFilter implements Filter
private volatile boolean _remotePort;
private volatile boolean _enabled;
private volatile String _name;
private DoSFilter.Listener _listener = new Listener();
private Semaphore _passes;
private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec;
private Queue<AsyncContext>[] _queues;
private AsyncListener[] _listeners;
private Map<RateType, Queue<AsyncContext>> _queues = new HashMap<>();
private Map<RateType, AsyncListener> _listeners = new HashMap<>();
private Scheduler _scheduler;
private ServletContext _context;
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
_queues = new Queue[getMaxPriority() + 1];
_listeners = new AsyncListener[_queues.length];
for (int p = 0; p < _queues.length; p++)
for (RateType rateType : RateType.values())
{
_queues[p] = new ConcurrentLinkedQueue<>();
_listeners[p] = new DoSAsyncListener(p);
_queues.put(rateType, new ConcurrentLinkedQueue<>());
_listeners.put(rateType, new DoSAsyncListener(rateType));
}
_rateTrackers.clear();
@ -305,67 +311,76 @@ public class DoSFilter implements Filter
// Look for the rate tracker for this request.
RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
if (tracker == null)
if (tracker != null)
{
// This is the first time we have seen this request.
if (LOG.isDebugEnabled())
LOG.debug("Filtering {}", request);
// Get a rate tracker associated with this request, and record one hit.
tracker = getRateTracker(request);
// Calculate the rate and check if it is over the allowed limit
final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis());
// Pass it through if we are not currently over the rate limit.
if (!overRateLimit)
{
if (LOG.isDebugEnabled())
LOG.debug("Allowing {}", request);
doFilterChain(filterChain, request, response);
return;
}
// We are over the limit.
// So either reject it, delay it or throttle it.
long delayMs = getDelayMs();
boolean insertHeaders = isInsertHeaders();
switch ((int)delayMs)
{
case -1:
{
// Reject this request.
LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
if (insertHeaders)
response.addHeader("DoSFilter", "unavailable");
response.sendError(getTooManyCode());
return;
}
case 0:
{
// Fall through to throttle the request.
LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
request.setAttribute(__TRACKER, tracker);
break;
}
default:
{
// Insert a delay before throttling the request,
// using the suspend+timeout mechanism of AsyncContext.
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
if (insertHeaders)
response.addHeader("DoSFilter", "delayed");
request.setAttribute(__TRACKER, tracker);
AsyncContext asyncContext = request.startAsync();
if (delayMs > 0)
asyncContext.setTimeout(delayMs);
asyncContext.addListener(new DoSTimeoutAsyncListener());
return;
}
}
// Redispatched, RateTracker present in request attributes.
throttleRequest(request, response, filterChain, tracker);
return;
}
// This is the first time we have seen this request.
if (LOG.isDebugEnabled())
LOG.debug("Filtering {}", request);
// Get a rate tracker associated with this request, and record one hit.
tracker = getRateTracker(request);
// Calculate the rate and check if it is over the allowed limit
final OverLimit overLimit = tracker.isRateExceeded(System.currentTimeMillis());
// Pass it through if we are not currently over the rate limit.
if (overLimit == null)
{
if (LOG.isDebugEnabled())
LOG.debug("Allowing {}", request);
doFilterChain(filterChain, request, response);
return;
}
// We are over the limit.
// Ask listener what to perform.
Action action = _listener.onRequestOverLimit(request, overLimit, this);
// Perform action
long delayMs = getDelayMs();
boolean insertHeaders = isInsertHeaders();
switch (action)
{
case NO_ACTION:
if (LOG.isDebugEnabled())
LOG.debug("Allowing over-limit request {}", request);
doFilterChain(filterChain, request, response);
break;
case ABORT:
if (LOG.isDebugEnabled())
LOG.debug("Aborting over-limit request {}", request);
response.sendError(-1);
return;
case REJECT:
if (insertHeaders)
response.addHeader("DoSFilter", "unavailable");
response.sendError(getTooManyCode());
return;
case DELAY:
// Insert a delay before throttling the request,
// using the suspend+timeout mechanism of AsyncContext.
if (insertHeaders)
response.addHeader("DoSFilter", "delayed");
request.setAttribute(__TRACKER, tracker);
AsyncContext asyncContext = request.startAsync();
if (delayMs > 0)
asyncContext.setTimeout(delayMs);
asyncContext.addListener(new DoSTimeoutAsyncListener());
break;
case THROTTLE:
throttleRequest(request, response, filterChain, tracker);
break;
}
}
private void throttleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain, RateTracker tracker) throws IOException, ServletException
{
if (LOG.isDebugEnabled())
LOG.debug("Throttling {}", request);
@ -383,15 +398,15 @@ public class DoSFilter implements Filter
long throttleMs = getThrottleMs();
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
{
int priority = getPriority(request, tracker);
RateType priority = getPriority(request, tracker);
request.setAttribute(__THROTTLED, Boolean.TRUE);
if (isInsertHeaders())
response.addHeader("DoSFilter", "throttled");
AsyncContext asyncContext = request.startAsync();
request.setAttribute(_suspended, Boolean.TRUE);
asyncContext.setTimeout(throttleMs);
asyncContext.addListener(_listeners[priority]);
_queues[priority].add(asyncContext);
asyncContext.addListener(_listeners.get(priority));
_queues.get(priority).add(asyncContext);
if (LOG.isDebugEnabled())
LOG.debug("Throttled {}, {}ms", request, throttleMs);
return;
@ -436,9 +451,9 @@ public class DoSFilter implements Filter
try
{
// Wake up the next highest priority request.
for (int p = _queues.length - 1; p >= 0; --p)
for (RateType rateType : RateType.values())
{
AsyncContext asyncContext = _queues[p].poll();
AsyncContext asyncContext = _queues.get(rateType).poll();
if (asyncContext != null)
{
ServletRequest candidate = asyncContext.getRequest();
@ -530,21 +545,31 @@ public class DoSFilter implements Filter
* @param tracker the rate tracker for this request
* @return the priority for this request
*/
private int getPriority(HttpServletRequest request, RateTracker tracker)
private RateType getPriority(HttpServletRequest request, RateTracker tracker)
{
if (extractUserId(request) != null)
return USER_AUTH;
return RateType.AUTH;
if (tracker != null)
return tracker.getType();
return USER_UNKNOWN;
return RateType.UNKNOWN;
}
/**
* @return the maximum priority that we can assign to a request
*/
protected int getMaxPriority()
protected RateType getMaxPriority()
{
return USER_AUTH;
return RateType.AUTH;
}
public void setListener(DoSFilter.Listener listener)
{
_listener = Objects.requireNonNull(listener, "Listener may not be null");
}
public DoSFilter.Listener getListener()
{
return _listener;
}
private void schedule(RateTracker tracker)
@ -573,22 +598,22 @@ public class DoSFilter implements Filter
HttpSession session = ((HttpServletRequest)request).getSession(false);
String loadId = extractUserId(request);
final int type;
final RateType type;
if (loadId != null)
{
type = USER_AUTH;
type = RateType.AUTH;
}
else
{
if (isTrackSessions() && session != null && !session.isNew())
{
loadId = session.getId();
type = USER_SESSION;
type = RateType.SESSION;
}
else
{
loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
type = USER_IP;
type = RateType.IP;
}
}
@ -605,7 +630,7 @@ public class DoSFilter implements Filter
if (existing != null)
tracker = existing;
if (type == USER_IP)
if (type == RateType.IP)
{
// USER_IP expiration from _rateTrackers is handled by the _scheduler
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
@ -1062,6 +1087,11 @@ public class DoSFilter implements Filter
_enabled = enabled;
}
/**
* Status code for Rejected for too many requests.
*
* @return the configured status code (default: 429 - Too Many Requests)
*/
public int getTooManyCode()
{
return _tooManyCode;
@ -1150,6 +1180,13 @@ public class DoSFilter implements Filter
return _whitelist.remove(address);
}
private String createRemotePortId(ServletRequest request)
{
String addr = request.getRemoteAddr();
int port = request.getRemotePort();
return addr + ":" + port;
}
/**
* A RateTracker is associated with a connection, and stores request rate
* data.
@ -1161,17 +1198,19 @@ public class DoSFilter implements Filter
protected final String _filterName;
protected transient ServletContext _context;
protected final String _id;
protected final int _type;
protected final RateType _type;
protected final int _maxRequestsPerSecond;
protected final long[] _timestamps;
protected int _next;
public RateTracker(ServletContext context, String filterName, String id, int type, int maxRequestsPerSecond)
public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond)
{
_context = context;
_filterName = filterName;
_id = id;
_type = type;
_maxRequestsPerSecond = maxRequestsPerSecond;
_timestamps = new long[maxRequestsPerSecond];
_next = 0;
}
@ -1180,7 +1219,7 @@ public class DoSFilter implements Filter
* @param now the time now (in milliseconds)
* @return the current calculated request rate over the last second
*/
public boolean isRateExceeded(long now)
public OverLimit isRateExceeded(long now)
{
final long last;
synchronized (this)
@ -1190,7 +1229,17 @@ public class DoSFilter implements Filter
_next = (_next + 1) % _timestamps.length;
}
return last != 0 && (now - last) < 1000L;
if (last == 0)
{
return null;
}
long rate = (now - last);
if (rate < 1000L)
{
return new Overage(Duration.ofMillis(rate), _maxRequestsPerSecond);
}
return null;
}
public String getId()
@ -1198,7 +1247,7 @@ public class DoSFilter implements Filter
return _id;
}
public int getType()
public RateType getType()
{
return _type;
}
@ -1271,7 +1320,7 @@ public class DoSFilter implements Filter
{
if (_context == null)
{
LOG.warn("Unknkown context for rate tracker {}", this);
LOG.warn("Unknown context for rate tracker {}", this);
return;
}
@ -1297,17 +1346,66 @@ public class DoSFilter implements Filter
{
return "RateTracker/" + _id + "/" + _type;
}
public class Overage implements OverLimit
{
private final Duration duration;
private final long count;
public Overage(Duration dur, long count)
{
this.duration = dur;
this.count = count;
}
@Override
public RateType getRateType()
{
return _type;
}
@Override
public String getRateId()
{
return _id;
}
@Override
public Duration getDuration()
{
return duration;
}
@Override
public long getCount()
{
return count;
}
@Override
public String toString()
{
final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName());
sb.append('@').append(Integer.toHexString(hashCode()));
sb.append("[type=").append(getRateType());
sb.append(", id=").append(getRateId());
sb.append(", duration=").append(duration);
sb.append(", count=").append(count);
sb.append(']');
return sb.toString();
}
}
}
private static class FixedRateTracker extends RateTracker
{
public FixedRateTracker(ServletContext context, String filterName, String id, int type, int numRecentRequestsTracked)
public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked)
{
super(context, filterName, id, type, numRecentRequestsTracked);
}
@Override
public boolean isRateExceeded(long now)
public OverLimit isRateExceeded(long now)
{
// rate limit is never exceeded, but we keep track of the request timestamps
// so that we know whether there was recent activity on this tracker
@ -1318,7 +1416,7 @@ public class DoSFilter implements Filter
_next = (_next + 1) % _timestamps.length;
}
return false;
return null;
}
@Override
@ -1354,9 +1452,9 @@ public class DoSFilter implements Filter
private class DoSAsyncListener extends DoSTimeoutAsyncListener
{
private final int priority;
private final RateType priority;
public DoSAsyncListener(int priority)
public DoSAsyncListener(RateType priority)
{
this.priority = priority;
}
@ -1364,15 +1462,93 @@ public class DoSFilter implements Filter
@Override
public void onTimeout(AsyncEvent event) throws IOException
{
_queues[priority].remove(event.getAsyncContext());
_queues.get(priority).remove(event.getAsyncContext());
super.onTimeout(event);
}
}
private String createRemotePortId(ServletRequest request)
public enum Action
{
String addr = request.getRemoteAddr();
int port = request.getRemotePort();
return addr + ":" + port;
/**
* No action is taken against the Request, it is allowed to be processed normally.
*/
NO_ACTION,
/**
* The request and response is aborted, no response is sent.
*/
ABORT,
/**
* The request is rejected by sending an error based on {@link DoSFilter#getTooManyCode()}
*/
REJECT,
/**
* The request is delayed based on {@link DoSFilter#getDelayMs()}
*/
DELAY,
/**
* The request is throttled.
*/
THROTTLE;
/**
* Obtain the Action based on configured {@link DoSFilter#getDelayMs()}
*
* @param delayMs the delay in milliseconds.
* @return the Action proposed.
*/
public static Action fromDelay(long delayMs)
{
if (delayMs < 0)
return Action.REJECT;
if (delayMs == 0)
return Action.THROTTLE;
return Action.DELAY;
}
}
public interface OverLimit
{
RateType getRateType();
String getRateId();
Duration getDuration();
long getCount();
}
/**
* Listener for actions taken against specific requests.
*/
public static class Listener
{
/**
* Process the onRequestOverLimit() behavior.
*
* @param request the request that is over the limit
* @param dosFilter the {@link DoSFilter} that this event occurred on
* @return the action to actually perform.
*/
public Action onRequestOverLimit(HttpServletRequest request, OverLimit overlimit, DoSFilter dosFilter)
{
Action action = Action.fromDelay(dosFilter.getDelayMs());
switch (action)
{
case REJECT:
LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
break;
case DELAY:
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
break;
case THROTTLE:
LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
break;
}
return action;
}
}
}

View File

@ -174,12 +174,12 @@ public class DoSFilterTest extends AbstractDoSFilterTest
{
boolean exceeded = false;
ServletContext context = new ContextHandler.StaticContext();
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 0, 4);
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4);
for (int i = 0; i < 5; i++)
{
Thread.sleep(sleep);
if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime())))
if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime())) != null)
exceeded = true;
}
return exceeded;