Merge pull request #5195 from eclipse/jetty-9.4.x-5185-dosfilter-listener

Issue #5185 - Add DoSFilter Listener to allow extensible behavior
This commit is contained in:
Joakim Erdfelt 2020-09-03 14:10:29 -05:00 committed by GitHub
commit be86e66e77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 278 additions and 102 deletions

View File

@ -20,9 +20,13 @@ package org.eclipse.jetty.servlets;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
@ -161,10 +165,13 @@ public class DoSFilter implements Filter
static final String ENABLED_INIT_PARAM = "enabled"; static final String ENABLED_INIT_PARAM = "enabled";
static final String TOO_MANY_CODE = "tooManyCode"; static final String TOO_MANY_CODE = "tooManyCode";
private static final int USER_AUTH = 2; public enum RateType
private static final int USER_SESSION = 2; {
private static final int USER_IP = 1; AUTH,
private static final int USER_UNKNOWN = 0; SESSION,
IP,
UNKNOWN
}
private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED"; private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED";
private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED"; private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED";
@ -181,23 +188,22 @@ public class DoSFilter implements Filter
private volatile boolean _remotePort; private volatile boolean _remotePort;
private volatile boolean _enabled; private volatile boolean _enabled;
private volatile String _name; private volatile String _name;
private DoSFilter.Listener _listener = new Listener();
private Semaphore _passes; private Semaphore _passes;
private volatile int _throttledRequests; private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec; private volatile int _maxRequestsPerSec;
private Queue<AsyncContext>[] _queues; private Map<RateType, Queue<AsyncContext>> _queues = new HashMap<>();
private AsyncListener[] _listeners; private Map<RateType, AsyncListener> _listeners = new HashMap<>();
private Scheduler _scheduler; private Scheduler _scheduler;
private ServletContext _context; private ServletContext _context;
@Override @Override
public void init(FilterConfig filterConfig) throws ServletException public void init(FilterConfig filterConfig) throws ServletException
{ {
_queues = new Queue[getMaxPriority() + 1]; for (RateType rateType : RateType.values())
_listeners = new AsyncListener[_queues.length];
for (int p = 0; p < _queues.length; p++)
{ {
_queues[p] = new ConcurrentLinkedQueue<>(); _queues.put(rateType, new ConcurrentLinkedQueue<>());
_listeners[p] = new DoSAsyncListener(p); _listeners.put(rateType, new DoSAsyncListener(rateType));
} }
_rateTrackers.clear(); _rateTrackers.clear();
@ -305,8 +311,13 @@ public class DoSFilter implements Filter
// Look for the rate tracker for this request. // Look for the rate tracker for this request.
RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER); RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
if (tracker == null) if (tracker != null)
{ {
// Redispatched, RateTracker present in request attributes.
throttleRequest(request, response, filterChain, tracker);
return;
}
// This is the first time we have seen this request. // This is the first time we have seen this request.
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Filtering {}", request); LOG.debug("Filtering {}", request);
@ -315,10 +326,10 @@ public class DoSFilter implements Filter
tracker = getRateTracker(request); tracker = getRateTracker(request);
// Calculate the rate and check if it is over the allowed limit // Calculate the rate and check if it is over the allowed limit
final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis()); final OverLimit overLimit = tracker.isRateExceeded(System.currentTimeMillis());
// Pass it through if we are not currently over the rate limit. // Pass it through if we are not currently over the rate limit.
if (!overRateLimit) if (overLimit == null)
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Allowing {}", request); LOG.debug("Allowing {}", request);
@ -328,32 +339,32 @@ public class DoSFilter implements Filter
// We are over the limit. // We are over the limit.
// So either reject it, delay it or throttle it. // Ask listener what to perform.
Action action = _listener.onRequestOverLimit(request, overLimit, this);
// Perform action
long delayMs = getDelayMs(); long delayMs = getDelayMs();
boolean insertHeaders = isInsertHeaders(); boolean insertHeaders = isInsertHeaders();
switch ((int)delayMs) switch (action)
{ {
case -1: case NO_ACTION:
{ if (LOG.isDebugEnabled())
// Reject this request. LOG.debug("Allowing over-limit request {}", request);
LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal()); doFilterChain(filterChain, request, response);
break;
case ABORT:
if (LOG.isDebugEnabled())
LOG.debug("Aborting over-limit request {}", request);
response.sendError(-1);
return;
case REJECT:
if (insertHeaders) if (insertHeaders)
response.addHeader("DoSFilter", "unavailable"); response.addHeader("DoSFilter", "unavailable");
response.sendError(getTooManyCode()); response.sendError(getTooManyCode());
return; return;
} case DELAY:
case 0:
{
// Fall through to throttle the request.
LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
request.setAttribute(__TRACKER, tracker);
break;
}
default:
{
// Insert a delay before throttling the request, // Insert a delay before throttling the request,
// using the suspend+timeout mechanism of AsyncContext. // using the suspend+timeout mechanism of AsyncContext.
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
if (insertHeaders) if (insertHeaders)
response.addHeader("DoSFilter", "delayed"); response.addHeader("DoSFilter", "delayed");
request.setAttribute(__TRACKER, tracker); request.setAttribute(__TRACKER, tracker);
@ -361,11 +372,15 @@ public class DoSFilter implements Filter
if (delayMs > 0) if (delayMs > 0)
asyncContext.setTimeout(delayMs); asyncContext.setTimeout(delayMs);
asyncContext.addListener(new DoSTimeoutAsyncListener()); asyncContext.addListener(new DoSTimeoutAsyncListener());
return; break;
} case THROTTLE:
throttleRequest(request, response, filterChain, tracker);
break;
} }
} }
private void throttleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain, RateTracker tracker) throws IOException, ServletException
{
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Throttling {}", request); LOG.debug("Throttling {}", request);
@ -383,15 +398,15 @@ public class DoSFilter implements Filter
long throttleMs = getThrottleMs(); long throttleMs = getThrottleMs();
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0) if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
{ {
int priority = getPriority(request, tracker); RateType priority = getPriority(request, tracker);
request.setAttribute(__THROTTLED, Boolean.TRUE); request.setAttribute(__THROTTLED, Boolean.TRUE);
if (isInsertHeaders()) if (isInsertHeaders())
response.addHeader("DoSFilter", "throttled"); response.addHeader("DoSFilter", "throttled");
AsyncContext asyncContext = request.startAsync(); AsyncContext asyncContext = request.startAsync();
request.setAttribute(_suspended, Boolean.TRUE); request.setAttribute(_suspended, Boolean.TRUE);
asyncContext.setTimeout(throttleMs); asyncContext.setTimeout(throttleMs);
asyncContext.addListener(_listeners[priority]); asyncContext.addListener(_listeners.get(priority));
_queues[priority].add(asyncContext); _queues.get(priority).add(asyncContext);
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Throttled {}, {}ms", request, throttleMs); LOG.debug("Throttled {}, {}ms", request, throttleMs);
return; return;
@ -436,9 +451,9 @@ public class DoSFilter implements Filter
try try
{ {
// Wake up the next highest priority request. // Wake up the next highest priority request.
for (int p = _queues.length - 1; p >= 0; --p) for (RateType rateType : RateType.values())
{ {
AsyncContext asyncContext = _queues[p].poll(); AsyncContext asyncContext = _queues.get(rateType).poll();
if (asyncContext != null) if (asyncContext != null)
{ {
ServletRequest candidate = asyncContext.getRequest(); ServletRequest candidate = asyncContext.getRequest();
@ -530,21 +545,31 @@ public class DoSFilter implements Filter
* @param tracker the rate tracker for this request * @param tracker the rate tracker for this request
* @return the priority for this request * @return the priority for this request
*/ */
private int getPriority(HttpServletRequest request, RateTracker tracker) private RateType getPriority(HttpServletRequest request, RateTracker tracker)
{ {
if (extractUserId(request) != null) if (extractUserId(request) != null)
return USER_AUTH; return RateType.AUTH;
if (tracker != null) if (tracker != null)
return tracker.getType(); return tracker.getType();
return USER_UNKNOWN; return RateType.UNKNOWN;
} }
/** /**
* @return the maximum priority that we can assign to a request * @return the maximum priority that we can assign to a request
*/ */
protected int getMaxPriority() protected RateType getMaxPriority()
{ {
return USER_AUTH; return RateType.AUTH;
}
public void setListener(DoSFilter.Listener listener)
{
_listener = Objects.requireNonNull(listener, "Listener may not be null");
}
public DoSFilter.Listener getListener()
{
return _listener;
} }
private void schedule(RateTracker tracker) private void schedule(RateTracker tracker)
@ -573,22 +598,22 @@ public class DoSFilter implements Filter
HttpSession session = ((HttpServletRequest)request).getSession(false); HttpSession session = ((HttpServletRequest)request).getSession(false);
String loadId = extractUserId(request); String loadId = extractUserId(request);
final int type; final RateType type;
if (loadId != null) if (loadId != null)
{ {
type = USER_AUTH; type = RateType.AUTH;
} }
else else
{ {
if (isTrackSessions() && session != null && !session.isNew()) if (isTrackSessions() && session != null && !session.isNew())
{ {
loadId = session.getId(); loadId = session.getId();
type = USER_SESSION; type = RateType.SESSION;
} }
else else
{ {
loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr(); loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
type = USER_IP; type = RateType.IP;
} }
} }
@ -605,7 +630,7 @@ public class DoSFilter implements Filter
if (existing != null) if (existing != null)
tracker = existing; tracker = existing;
if (type == USER_IP) if (type == RateType.IP)
{ {
// USER_IP expiration from _rateTrackers is handled by the _scheduler // USER_IP expiration from _rateTrackers is handled by the _scheduler
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS); _scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
@ -1062,6 +1087,11 @@ public class DoSFilter implements Filter
_enabled = enabled; _enabled = enabled;
} }
/**
* Status code for Rejected for too many requests.
*
* @return the configured status code (default: 429 - Too Many Requests)
*/
public int getTooManyCode() public int getTooManyCode()
{ {
return _tooManyCode; return _tooManyCode;
@ -1150,6 +1180,13 @@ public class DoSFilter implements Filter
return _whitelist.remove(address); return _whitelist.remove(address);
} }
private String createRemotePortId(ServletRequest request)
{
String addr = request.getRemoteAddr();
int port = request.getRemotePort();
return addr + ":" + port;
}
/** /**
* A RateTracker is associated with a connection, and stores request rate * A RateTracker is associated with a connection, and stores request rate
* data. * data.
@ -1161,17 +1198,19 @@ public class DoSFilter implements Filter
protected final String _filterName; protected final String _filterName;
protected transient ServletContext _context; protected transient ServletContext _context;
protected final String _id; protected final String _id;
protected final int _type; protected final RateType _type;
protected final int _maxRequestsPerSecond;
protected final long[] _timestamps; protected final long[] _timestamps;
protected int _next; protected int _next;
public RateTracker(ServletContext context, String filterName, String id, int type, int maxRequestsPerSecond) public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond)
{ {
_context = context; _context = context;
_filterName = filterName; _filterName = filterName;
_id = id; _id = id;
_type = type; _type = type;
_maxRequestsPerSecond = maxRequestsPerSecond;
_timestamps = new long[maxRequestsPerSecond]; _timestamps = new long[maxRequestsPerSecond];
_next = 0; _next = 0;
} }
@ -1180,7 +1219,7 @@ public class DoSFilter implements Filter
* @param now the time now (in milliseconds) * @param now the time now (in milliseconds)
* @return the current calculated request rate over the last second * @return the current calculated request rate over the last second
*/ */
public boolean isRateExceeded(long now) public OverLimit isRateExceeded(long now)
{ {
final long last; final long last;
synchronized (this) synchronized (this)
@ -1190,7 +1229,17 @@ public class DoSFilter implements Filter
_next = (_next + 1) % _timestamps.length; _next = (_next + 1) % _timestamps.length;
} }
return last != 0 && (now - last) < 1000L; if (last == 0)
{
return null;
}
long rate = (now - last);
if (rate < 1000L)
{
return new Overage(Duration.ofMillis(rate), _maxRequestsPerSecond);
}
return null;
} }
public String getId() public String getId()
@ -1198,7 +1247,7 @@ public class DoSFilter implements Filter
return _id; return _id;
} }
public int getType() public RateType getType()
{ {
return _type; return _type;
} }
@ -1271,7 +1320,7 @@ public class DoSFilter implements Filter
{ {
if (_context == null) if (_context == null)
{ {
LOG.warn("Unknkown context for rate tracker {}", this); LOG.warn("Unknown context for rate tracker {}", this);
return; return;
} }
@ -1297,17 +1346,66 @@ public class DoSFilter implements Filter
{ {
return "RateTracker/" + _id + "/" + _type; return "RateTracker/" + _id + "/" + _type;
} }
public class Overage implements OverLimit
{
private final Duration duration;
private final long count;
public Overage(Duration dur, long count)
{
this.duration = dur;
this.count = count;
}
@Override
public RateType getRateType()
{
return _type;
}
@Override
public String getRateId()
{
return _id;
}
@Override
public Duration getDuration()
{
return duration;
}
@Override
public long getCount()
{
return count;
}
@Override
public String toString()
{
final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName());
sb.append('@').append(Integer.toHexString(hashCode()));
sb.append("[type=").append(getRateType());
sb.append(", id=").append(getRateId());
sb.append(", duration=").append(duration);
sb.append(", count=").append(count);
sb.append(']');
return sb.toString();
}
}
} }
private static class FixedRateTracker extends RateTracker private static class FixedRateTracker extends RateTracker
{ {
public FixedRateTracker(ServletContext context, String filterName, String id, int type, int numRecentRequestsTracked) public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked)
{ {
super(context, filterName, id, type, numRecentRequestsTracked); super(context, filterName, id, type, numRecentRequestsTracked);
} }
@Override @Override
public boolean isRateExceeded(long now) public OverLimit isRateExceeded(long now)
{ {
// rate limit is never exceeded, but we keep track of the request timestamps // rate limit is never exceeded, but we keep track of the request timestamps
// so that we know whether there was recent activity on this tracker // so that we know whether there was recent activity on this tracker
@ -1318,7 +1416,7 @@ public class DoSFilter implements Filter
_next = (_next + 1) % _timestamps.length; _next = (_next + 1) % _timestamps.length;
} }
return false; return null;
} }
@Override @Override
@ -1354,9 +1452,9 @@ public class DoSFilter implements Filter
private class DoSAsyncListener extends DoSTimeoutAsyncListener private class DoSAsyncListener extends DoSTimeoutAsyncListener
{ {
private final int priority; private final RateType priority;
public DoSAsyncListener(int priority) public DoSAsyncListener(RateType priority)
{ {
this.priority = priority; this.priority = priority;
} }
@ -1364,15 +1462,93 @@ public class DoSFilter implements Filter
@Override @Override
public void onTimeout(AsyncEvent event) throws IOException public void onTimeout(AsyncEvent event) throws IOException
{ {
_queues[priority].remove(event.getAsyncContext()); _queues.get(priority).remove(event.getAsyncContext());
super.onTimeout(event); super.onTimeout(event);
} }
} }
private String createRemotePortId(ServletRequest request) public enum Action
{ {
String addr = request.getRemoteAddr(); /**
int port = request.getRemotePort(); * No action is taken against the Request, it is allowed to be processed normally.
return addr + ":" + port; */
NO_ACTION,
/**
* The request and response is aborted, no response is sent.
*/
ABORT,
/**
* The request is rejected by sending an error based on {@link DoSFilter#getTooManyCode()}
*/
REJECT,
/**
* The request is delayed based on {@link DoSFilter#getDelayMs()}
*/
DELAY,
/**
* The request is throttled.
*/
THROTTLE;
/**
* Obtain the Action based on configured {@link DoSFilter#getDelayMs()}
*
* @param delayMs the delay in milliseconds.
* @return the Action proposed.
*/
public static Action fromDelay(long delayMs)
{
if (delayMs < 0)
return Action.REJECT;
if (delayMs == 0)
return Action.THROTTLE;
return Action.DELAY;
}
}
public interface OverLimit
{
RateType getRateType();
String getRateId();
Duration getDuration();
long getCount();
}
/**
* Listener for actions taken against specific requests.
*/
public static class Listener
{
/**
* Process the onRequestOverLimit() behavior.
*
* @param request the request that is over the limit
* @param dosFilter the {@link DoSFilter} that this event occurred on
* @return the action to actually perform.
*/
public Action onRequestOverLimit(HttpServletRequest request, OverLimit overlimit, DoSFilter dosFilter)
{
Action action = Action.fromDelay(dosFilter.getDelayMs());
switch (action)
{
case REJECT:
LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
break;
case DELAY:
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
break;
case THROTTLE:
LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
break;
}
return action;
}
} }
} }

View File

@ -174,12 +174,12 @@ public class DoSFilterTest extends AbstractDoSFilterTest
{ {
boolean exceeded = false; boolean exceeded = false;
ServletContext context = new ContextHandler.StaticContext(); ServletContext context = new ContextHandler.StaticContext();
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 0, 4); RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4);
for (int i = 0; i < 5; i++) for (int i = 0; i < 5; i++)
{ {
Thread.sleep(sleep); Thread.sleep(sleep);
if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime()))) if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime())) != null)
exceeded = true; exceeded = true;
} }
return exceeded; return exceeded;