Simplify the DosFilter for #1256 (#10748)

Use only IP tracking for the DosFilter to fix #1256

Signed-off-by: gregw <gregw@webtide.com>
This commit is contained in:
Greg Wilkins 2023-10-18 22:34:54 +02:00 committed by GitHub
parent 8c94490e18
commit 168d8715d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 248 deletions

View File

@ -32,9 +32,7 @@ The filter works on the assumption that the attacker might be written in simple
[[dos-filter-using]]
==== Using the DoS Filter
Jetty places throttled requests in a priority queue, giving priority first to authenticated users and users with an HttpSession, then to connections identified by their IP addresses.
Connections with no way to identify them have lowest priority.
To uniquely identify authenticated users, you should implement the The extractUserId(ServletRequest request) function.
Jetty places throttled requests in a queue, and proceed only when there is capacity available.
===== Required JARs
@ -94,11 +92,8 @@ Default is 30000L.
insertHeaders::
If true, insert the DoSFilter headers into the response.
Defaults to true.
trackSessions::
If true, usage rate is tracked by session if a session exists.
Defaults to true.
remotePort::
If true and session tracking is not used, then rate is tracked by IP and port (effectively connection).
If true, then rate is tracked by IP and port (effectively connection).
Defaults to false.
ipWhitelist::
A comma-separated list of IP addresses that will not be rate limited.

View File

@ -17,10 +17,8 @@ import java.io.IOException;
import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
@ -43,11 +41,6 @@ import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionActivationListener;
import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import javax.servlet.http.HttpSessionEvent;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.handler.ContextHandler;
@ -74,10 +67,8 @@ import org.slf4j.LoggerFactory;
* second. If a limit is exceeded, the request is either rejected, delayed, or
* throttled.
* <p>
* When a request is throttled, it is placed in a priority queue. Priority is
* given first to authenticated users and users with an HttpSession, then
* connections which can be identified by their IP addresses. Connections with
* no way to identify them are given lowest priority.
* When a request is throttled, it is placed in a queue and will only proceed
* when there is capacity.
* <p>
* The {@link #extractUserId(ServletRequest request)} function should be
* implemented, in order to uniquely identify authenticated users.
@ -106,10 +97,8 @@ import org.slf4j.LoggerFactory;
* before deciding that the user has gone away, and discarding it</dd>
* <dt>insertHeaders</dt>
* <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd>
* <dt>trackSessions</dt>
* <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
* <dt>remotePort</dt>
* <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
* <dd>if true then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
* <dt>ipWhitelist</dt>
* <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
* <dt>managedAttr</dt>
@ -156,12 +145,14 @@ public class DoSFilter implements Filter
static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
@Deprecated
static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
static final String REMOTE_PORT_INIT_PARAM = "remotePort";
static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
static final String ENABLED_INIT_PARAM = "enabled";
static final String TOO_MANY_CODE = "tooManyCode";
@Deprecated
public enum RateType
{
AUTH,
@ -181,7 +172,6 @@ public class DoSFilter implements Filter
private volatile long _maxRequestMs;
private volatile long _maxIdleTrackerMs;
private volatile boolean _insertHeaders;
private volatile boolean _trackSessions;
private volatile boolean _remotePort;
private volatile boolean _enabled;
private volatile String _name;
@ -189,20 +179,14 @@ public class DoSFilter implements Filter
private Semaphore _passes;
private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec;
private Map<RateType, Queue<AsyncContext>> _queues = new HashMap<>();
private Map<RateType, AsyncListener> _listeners = new HashMap<>();
private final Queue<AsyncContext> _queue = new ConcurrentLinkedQueue<>();
private final AsyncListener _asyncListener = new DoSAsyncListener();
private Scheduler _scheduler;
private ServletContext _context;
@Override
public void init(FilterConfig filterConfig) throws ServletException
{
for (RateType rateType : RateType.values())
{
_queues.put(rateType, new ConcurrentLinkedQueue<>());
_listeners.put(rateType, new DoSAsyncListener(rateType));
}
_rateTrackers.clear();
int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
@ -395,15 +379,14 @@ public class DoSFilter implements Filter
long throttleMs = getThrottleMs();
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
{
RateType priority = getPriority(request, tracker);
request.setAttribute(__THROTTLED, Boolean.TRUE);
if (isInsertHeaders())
response.addHeader("DoSFilter", "throttled");
AsyncContext asyncContext = request.startAsync();
request.setAttribute(_suspended, Boolean.TRUE);
asyncContext.setTimeout(throttleMs);
asyncContext.addListener(_listeners.get(priority));
_queues.get(priority).add(asyncContext);
asyncContext.addListener(_asyncListener);
_queue.add(asyncContext);
if (LOG.isDebugEnabled())
LOG.debug("Throttled {}, {}ms", request, throttleMs);
return;
@ -447,22 +430,17 @@ public class DoSFilter implements Filter
{
try
{
// Wake up the next highest priority request.
for (RateType rateType : RateType.values())
AsyncContext asyncContext = _queue.poll();
if (asyncContext != null)
{
AsyncContext asyncContext = _queues.get(rateType).poll();
if (asyncContext != null)
ServletRequest candidate = asyncContext.getRequest();
Boolean suspended = (Boolean)candidate.getAttribute(_suspended);
if (Boolean.TRUE.equals(suspended))
{
ServletRequest candidate = asyncContext.getRequest();
Boolean suspended = (Boolean)candidate.getAttribute(_suspended);
if (Boolean.TRUE.equals(suspended))
{
if (LOG.isDebugEnabled())
LOG.debug("Resuming {}", request);
candidate.setAttribute(_resumed, Boolean.TRUE);
asyncContext.dispatch();
break;
}
if (LOG.isDebugEnabled())
LOG.debug("Resuming {}", request);
candidate.setAttribute(_resumed, Boolean.TRUE);
asyncContext.dispatch();
}
}
}
@ -524,27 +502,13 @@ public class DoSFilter implements Filter
}
/**
* Get priority for this request, based on user type
*
* @param request the current request
* @param tracker the rate tracker for this request
* @return the priority for this request
*/
private RateType getPriority(HttpServletRequest request, RateTracker tracker)
{
if (extractUserId(request) != null)
return RateType.AUTH;
if (tracker != null)
return tracker.getType();
return RateType.UNKNOWN;
}
/**
* @return the maximum priority that we can assign to a request
* @return null
* @deprecated Priority no longer supported
*/
@Deprecated
protected RateType getMaxPriority()
{
return RateType.AUTH;
return null;
}
public void setListener(DoSFilter.Listener listener)
@ -570,61 +534,29 @@ public class DoSFilter implements Filter
* <p>
* Assumes that each connection has an identifying characteristic, and goes
* through them in order, taking the first that matches: user id (logged
* in), session id, client IP address. Unidentifiable connections are lumped
* in), client IP address. Unidentifiable connections are lumped
* into one.
* <p>
* When a session expires, its rate tracker is automatically deleted.
*
* @param request the current request
* @return the request rate tracker for the current connection
*/
RateTracker getRateTracker(ServletRequest request)
{
HttpSession session = ((HttpServletRequest)request).getSession(false);
String loadId = extractUserId(request);
final RateType type;
if (loadId != null)
{
type = RateType.AUTH;
}
else
{
if (isTrackSessions() && session != null && !session.isNew())
{
loadId = session.getId();
type = RateType.SESSION;
}
else
{
loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
type = RateType.IP;
}
}
String loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
RateTracker tracker = _rateTrackers.get(loadId);
if (tracker == null)
{
boolean allowed = checkWhitelist(request.getRemoteAddr());
int maxRequestsPerSec = getMaxRequestsPerSec();
tracker = allowed ? new FixedRateTracker(_context, _name, loadId, type, maxRequestsPerSec)
: new RateTracker(_context, _name, loadId, type, maxRequestsPerSec);
tracker = allowed ? new FixedRateTracker(_context, _name, loadId, maxRequestsPerSec)
: new RateTracker(_context, _name, loadId, maxRequestsPerSec);
tracker.setContext(_context);
RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
if (existing != null)
tracker = existing;
if (type == RateType.IP)
{
// USER_IP expiration from _rateTrackers is handled by the _scheduler
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
}
else if (session != null)
{
// USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
session.setAttribute(__TRACKER, tracker);
}
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
}
return tracker;
@ -750,7 +682,7 @@ public class DoSFilter implements Filter
return result;
// Sets the _prefix_ most significant bits to 1
result[index] = (byte)~((1 << (8 - prefix)) - 1);
result[index] = (byte)-(1 << (8 - prefix));
return result;
}
@ -776,12 +708,11 @@ public class DoSFilter implements Filter
}
/**
* Returns the user id, used to track this connection.
* This SHOULD be overridden by subclasses.
*
* @param request the current request
* @return a unique user id, if logged in; otherwise null.
* @param request ignored
* @return null
* @deprecated User ID no longer supported
*/
@Deprecated
protected String extractUserId(ServletRequest request)
{
return null;
@ -996,26 +927,29 @@ public class DoSFilter implements Filter
* Get flag to have usage rate tracked by session if a session exists.
*
* @return value of the flag
* @deprecated Session tracking is no longer supported
*/
@ManagedAttribute("usage rate is tracked by session if one exists")
@Deprecated
public boolean isTrackSessions()
{
return _trackSessions;
return false;
}
/**
* Set flag to have usage rate tracked by session if a session exists.
*
* @param value value of the flag
* @deprecated Session tracking is no longer supported
*/
@Deprecated
public void setTrackSessions(boolean value)
{
_trackSessions = value;
if (value)
LOG.warn("Session Tracking is not supported");
}
/**
* Get flag to have usage rate tracked by IP+port (effectively connection)
* if session tracking is not used.
*
* @return value of the flag
*/
@ -1027,7 +961,6 @@ public class DoSFilter implements Filter
/**
* Set flag to have usage rate tracked by IP+port (effectively connection)
* if session tracking is not used.
*
* @param value value of the flag
*/
@ -1130,7 +1063,10 @@ public class DoSFilter implements Filter
private boolean addWhitelistAddress(List<String> list, String address)
{
address = address.trim();
return address.length() > 0 && list.add(address);
if (address.length() <= 0)
return false;
list.add(address);
return true;
}
/**
@ -1157,7 +1093,7 @@ public class DoSFilter implements Filter
* A RateTracker is associated with a connection, and stores request rate
* data.
*/
static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable
static class RateTracker implements Runnable, Serializable
{
private static final long serialVersionUID = 3534663738034577872L;
@ -1165,18 +1101,16 @@ public class DoSFilter implements Filter
protected final String _filterName;
protected transient ServletContext _context;
protected final String _id;
protected final RateType _type;
protected final int _maxRequestsPerSecond;
protected final long[] _timestamps;
protected int _next;
public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond)
RateTracker(ServletContext context, String filterName, String id, int maxRequestsPerSecond)
{
_context = context;
_filterName = filterName;
_id = id;
_type = type;
_maxRequestsPerSecond = maxRequestsPerSecond;
_timestamps = new long[maxRequestsPerSecond];
_next = 0;
@ -1212,52 +1146,6 @@ public class DoSFilter implements Filter
return _id;
}
public RateType getType()
{
return _type;
}
@Override
public void valueBound(HttpSessionBindingEvent event)
{
if (LOG.isDebugEnabled())
LOG.debug("Value bound: {}", getId());
_context = event.getSession().getServletContext();
}
@Override
public void valueUnbound(HttpSessionBindingEvent event)
{
//take the tracker out of the list of trackers
DoSFilter filter = (DoSFilter)event.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
}
@Override
public void sessionWillPassivate(HttpSessionEvent se)
{
//take the tracker of the list of trackers (if its still there)
DoSFilter filter = (DoSFilter)se.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
}
@Override
public void sessionDidActivate(HttpSessionEvent se)
{
RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
ServletContext context = se.getSession().getServletContext();
tracker.setContext(context);
DoSFilter filter = (DoSFilter)context.getAttribute(_filterName);
if (filter == null)
{
LOG.info("No filter {} for rate tracker {}", _filterName, tracker);
return;
}
addToRateTrackers(filter, tracker);
}
public void setContext(ServletContext context)
{
_context = context;
@ -1309,7 +1197,7 @@ public class DoSFilter implements Filter
@Override
public String toString()
{
return "RateTracker/" + _id + "/" + _type;
return "RateTracker/" + _id;
}
public class Overage implements OverLimit
@ -1323,12 +1211,6 @@ public class DoSFilter implements Filter
this.count = count;
}
@Override
public RateType getRateType()
{
return _type;
}
@Override
public String getRateId()
{
@ -1350,23 +1232,20 @@ public class DoSFilter implements Filter
@Override
public String toString()
{
final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName());
sb.append('@').append(Integer.toHexString(hashCode()));
sb.append("[type=").append(getRateType());
sb.append(", id=").append(getRateId());
sb.append(", duration=").append(duration);
sb.append(", count=").append(count);
sb.append(']');
return sb.toString();
return OverLimit.class.getSimpleName() + '@' + Integer.toHexString(hashCode()) +
"[id=" + getRateId() +
", duration=" + duration +
", count=" + count +
']';
}
}
}
private static class FixedRateTracker extends RateTracker
{
public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked)
public FixedRateTracker(ServletContext context, String filterName, String id, int numRecentRequestsTracked)
{
super(context, filterName, id, type, numRecentRequestsTracked);
super(context, filterName, id, numRecentRequestsTracked);
}
@Override
@ -1417,17 +1296,14 @@ public class DoSFilter implements Filter
private class DoSAsyncListener extends DoSTimeoutAsyncListener
{
private final RateType priority;
public DoSAsyncListener(RateType priority)
public DoSAsyncListener()
{
this.priority = priority;
}
@Override
public void onTimeout(AsyncEvent event) throws IOException
{
_queues.get(priority).remove(event.getAsyncContext());
_queue.remove(event.getAsyncContext()); // TODO what???
super.onTimeout(event);
}
}
@ -1475,8 +1351,6 @@ public class DoSFilter implements Filter
public interface OverLimit
{
RateType getRateType();
String getRateId();
Duration getDuration();
@ -1503,13 +1377,13 @@ public class DoSFilter implements Filter
switch (action)
{
case REJECT:
LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
LOG.warn("DoS ALERT: Request rejected ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal());
break;
case DELAY:
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
LOG.warn("DoS ALERT: Request delayed={}ms, ip={}, overlimit={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getUserPrincipal());
break;
case THROTTLE:
LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
LOG.warn("DoS ALERT: Request throttled ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal());
break;
}

View File

@ -45,9 +45,7 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public abstract class AbstractDoSFilterTest
{
@ -248,60 +246,6 @@ public abstract class AbstractDoSFilterTest
other.join();
}
@Test
public void testSessionTracking() throws Exception
{
// get a session, first
String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String response = doRequests("", 1, 0, 0, requestSession);
String sessionId = response.substring(response.indexOf("Set-Cookie: ") + 12, response.indexOf(";"));
// all other requests use this session
String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n";
String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n";
String responses = doRequests(request + request + request + request + request, 2, 1100, 1100, last);
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
assertEquals(2, count(responses, "DoSFilter: delayed"));
}
@Test
public void testMultipleSessionTracking() throws Exception
{
// get some session ids, first
String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n";
String closeRequest = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String response = doRequests(requestSession + requestSession, 1, 0, 0, closeRequest);
String[] sessions = response.split("\r\n\r\n");
String sessionId1 = sessions[0].substring(sessions[0].indexOf("Set-Cookie: ") + 12, sessions[0].indexOf(";"));
String sessionId2 = sessions[1].substring(sessions[1].indexOf("Set-Cookie: ") + 12, sessions[1].indexOf(";"));
// alternate between sessions
String request1 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n";
String request2 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n";
String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n";
// ensure the sessions are new
doRequests(request1 + request2, 1, 1100, 1100, last);
Thread.sleep(1000);
String responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 1100, 1100, last);
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
// This test is system speed dependent, so allow some (20%-ish) requests to be delayed, but not more.
assertThat("delayed count", count(responses, "DoSFilter: delayed"), lessThan(2));
// alternate between sessions
responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 250, 250, last);
// System.err.println(responses);
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
int delayedRequests = count(responses, "DoSFilter: delayed");
assertTrue(delayedRequests >= 2 && delayedRequests <= 5, "delayedRequests: " + delayedRequests + " is not between 2 and 5");
}
@Test
public void testUnresponsiveClient() throws Exception
{

View File

@ -169,7 +169,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
{
boolean exceeded = false;
ServletContext context = new ContextHandler.StaticContext();
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4);
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 4);
for (int i = 0; i < 5; i++)
{