From 168d8715d4148dafe7f818c4bb01dcd0d6e2347f Mon Sep 17 00:00:00 2001 From: Greg Wilkins Date: Wed, 18 Oct 2023 22:34:54 +0200 Subject: [PATCH] Simplify the DosFilter for #1256 (#10748) Use only IP tracking for the DosFilter to fix #1256 Signed-off-by: gregw --- .../asciidoc/old_docs/extras/dos-filter.adoc | 9 +- .../org/eclipse/jetty/servlets/DoSFilter.java | 242 +++++------------- .../jetty/servlets/AbstractDoSFilterTest.java | 56 ---- .../eclipse/jetty/servlets/DoSFilterTest.java | 2 +- 4 files changed, 61 insertions(+), 248 deletions(-) diff --git a/documentation/jetty-documentation/src/main/asciidoc/old_docs/extras/dos-filter.adoc b/documentation/jetty-documentation/src/main/asciidoc/old_docs/extras/dos-filter.adoc index 7961c2d0073..795e752a20f 100644 --- a/documentation/jetty-documentation/src/main/asciidoc/old_docs/extras/dos-filter.adoc +++ b/documentation/jetty-documentation/src/main/asciidoc/old_docs/extras/dos-filter.adoc @@ -32,9 +32,7 @@ The filter works on the assumption that the attacker might be written in simple [[dos-filter-using]] ==== Using the DoS Filter -Jetty places throttled requests in a priority queue, giving priority first to authenticated users and users with an HttpSession, then to connections identified by their IP addresses. -Connections with no way to identify them have lowest priority. -To uniquely identify authenticated users, you should implement the The extractUserId(ServletRequest request) function. +Jetty places throttled requests in a queue, and proceed only when there is capacity available. ===== Required JARs @@ -94,11 +92,8 @@ Default is 30000L. insertHeaders:: If true, insert the DoSFilter headers into the response. Defaults to true. -trackSessions:: -If true, usage rate is tracked by session if a session exists. -Defaults to true. remotePort:: -If true and session tracking is not used, then rate is tracked by IP and port (effectively connection). +If true, then rate is tracked by IP and port (effectively connection). Defaults to false. ipWhitelist:: A comma-separated list of IP addresses that will not be rate limited. diff --git a/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java b/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java index 2f08d4b9977..f3f6a09e02b 100644 --- a/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java +++ b/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java @@ -17,10 +17,8 @@ import java.io.IOException; import java.io.Serializable; import java.time.Duration; import java.util.ArrayList; -import java.util.HashMap; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; @@ -43,11 +41,6 @@ import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; -import javax.servlet.http.HttpSessionActivationListener; -import javax.servlet.http.HttpSessionBindingEvent; -import javax.servlet.http.HttpSessionBindingListener; -import javax.servlet.http.HttpSessionEvent; import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.server.handler.ContextHandler; @@ -74,10 +67,8 @@ import org.slf4j.LoggerFactory; * second. If a limit is exceeded, the request is either rejected, delayed, or * throttled. *

- * When a request is throttled, it is placed in a priority queue. Priority is - * given first to authenticated users and users with an HttpSession, then - * connections which can be identified by their IP addresses. Connections with - * no way to identify them are given lowest priority. + * When a request is throttled, it is placed in a queue and will only proceed + * when there is capacity. *

* The {@link #extractUserId(ServletRequest request)} function should be * implemented, in order to uniquely identify authenticated users. @@ -106,10 +97,8 @@ import org.slf4j.LoggerFactory; * before deciding that the user has gone away, and discarding it *

insertHeaders
*
if true , insert the DoSFilter headers into the response. Defaults to true.
- *
trackSessions
- *
if true, usage rate is tracked by session if a session exists. Defaults to true.
*
remotePort
- *
if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
+ *
if true then rate is tracked by IP+port (effectively connection). Defaults to false.
*
ipWhitelist
*
a comma-separated list of IP addresses that will not be rate limited
*
managedAttr
@@ -156,12 +145,14 @@ public class DoSFilter implements Filter static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs"; static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs"; static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders"; + @Deprecated static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions"; static final String REMOTE_PORT_INIT_PARAM = "remotePort"; static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist"; static final String ENABLED_INIT_PARAM = "enabled"; static final String TOO_MANY_CODE = "tooManyCode"; + @Deprecated public enum RateType { AUTH, @@ -181,7 +172,6 @@ public class DoSFilter implements Filter private volatile long _maxRequestMs; private volatile long _maxIdleTrackerMs; private volatile boolean _insertHeaders; - private volatile boolean _trackSessions; private volatile boolean _remotePort; private volatile boolean _enabled; private volatile String _name; @@ -189,20 +179,14 @@ public class DoSFilter implements Filter private Semaphore _passes; private volatile int _throttledRequests; private volatile int _maxRequestsPerSec; - private Map> _queues = new HashMap<>(); - private Map _listeners = new HashMap<>(); + private final Queue _queue = new ConcurrentLinkedQueue<>(); + private final AsyncListener _asyncListener = new DoSAsyncListener(); private Scheduler _scheduler; private ServletContext _context; @Override public void init(FilterConfig filterConfig) throws ServletException { - for (RateType rateType : RateType.values()) - { - _queues.put(rateType, new ConcurrentLinkedQueue<>()); - _listeners.put(rateType, new DoSAsyncListener(rateType)); - } - _rateTrackers.clear(); int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC; @@ -395,15 +379,14 @@ public class DoSFilter implements Filter long throttleMs = getThrottleMs(); if (!Boolean.TRUE.equals(throttled) && throttleMs > 0) { - RateType priority = getPriority(request, tracker); request.setAttribute(__THROTTLED, Boolean.TRUE); if (isInsertHeaders()) response.addHeader("DoSFilter", "throttled"); AsyncContext asyncContext = request.startAsync(); request.setAttribute(_suspended, Boolean.TRUE); asyncContext.setTimeout(throttleMs); - asyncContext.addListener(_listeners.get(priority)); - _queues.get(priority).add(asyncContext); + asyncContext.addListener(_asyncListener); + _queue.add(asyncContext); if (LOG.isDebugEnabled()) LOG.debug("Throttled {}, {}ms", request, throttleMs); return; @@ -447,22 +430,17 @@ public class DoSFilter implements Filter { try { - // Wake up the next highest priority request. - for (RateType rateType : RateType.values()) + AsyncContext asyncContext = _queue.poll(); + if (asyncContext != null) { - AsyncContext asyncContext = _queues.get(rateType).poll(); - if (asyncContext != null) + ServletRequest candidate = asyncContext.getRequest(); + Boolean suspended = (Boolean)candidate.getAttribute(_suspended); + if (Boolean.TRUE.equals(suspended)) { - ServletRequest candidate = asyncContext.getRequest(); - Boolean suspended = (Boolean)candidate.getAttribute(_suspended); - if (Boolean.TRUE.equals(suspended)) - { - if (LOG.isDebugEnabled()) - LOG.debug("Resuming {}", request); - candidate.setAttribute(_resumed, Boolean.TRUE); - asyncContext.dispatch(); - break; - } + if (LOG.isDebugEnabled()) + LOG.debug("Resuming {}", request); + candidate.setAttribute(_resumed, Boolean.TRUE); + asyncContext.dispatch(); } } } @@ -524,27 +502,13 @@ public class DoSFilter implements Filter } /** - * Get priority for this request, based on user type - * - * @param request the current request - * @param tracker the rate tracker for this request - * @return the priority for this request - */ - private RateType getPriority(HttpServletRequest request, RateTracker tracker) - { - if (extractUserId(request) != null) - return RateType.AUTH; - if (tracker != null) - return tracker.getType(); - return RateType.UNKNOWN; - } - - /** - * @return the maximum priority that we can assign to a request + * @return null + * @deprecated Priority no longer supported */ + @Deprecated protected RateType getMaxPriority() { - return RateType.AUTH; + return null; } public void setListener(DoSFilter.Listener listener) @@ -570,61 +534,29 @@ public class DoSFilter implements Filter *

* Assumes that each connection has an identifying characteristic, and goes * through them in order, taking the first that matches: user id (logged - * in), session id, client IP address. Unidentifiable connections are lumped + * in), client IP address. Unidentifiable connections are lumped * into one. - *

- * When a session expires, its rate tracker is automatically deleted. * * @param request the current request * @return the request rate tracker for the current connection */ RateTracker getRateTracker(ServletRequest request) { - HttpSession session = ((HttpServletRequest)request).getSession(false); - - String loadId = extractUserId(request); - final RateType type; - if (loadId != null) - { - type = RateType.AUTH; - } - else - { - if (isTrackSessions() && session != null && !session.isNew()) - { - loadId = session.getId(); - type = RateType.SESSION; - } - else - { - loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr(); - type = RateType.IP; - } - } - + String loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr(); RateTracker tracker = _rateTrackers.get(loadId); if (tracker == null) { boolean allowed = checkWhitelist(request.getRemoteAddr()); int maxRequestsPerSec = getMaxRequestsPerSec(); - tracker = allowed ? new FixedRateTracker(_context, _name, loadId, type, maxRequestsPerSec) - : new RateTracker(_context, _name, loadId, type, maxRequestsPerSec); + tracker = allowed ? new FixedRateTracker(_context, _name, loadId, maxRequestsPerSec) + : new RateTracker(_context, _name, loadId, maxRequestsPerSec); tracker.setContext(_context); RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker); if (existing != null) tracker = existing; - if (type == RateType.IP) - { - // USER_IP expiration from _rateTrackers is handled by the _scheduler - _scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS); - } - else if (session != null) - { - // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener - session.setAttribute(__TRACKER, tracker); - } + _scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS); } return tracker; @@ -750,7 +682,7 @@ public class DoSFilter implements Filter return result; // Sets the _prefix_ most significant bits to 1 - result[index] = (byte)~((1 << (8 - prefix)) - 1); + result[index] = (byte)-(1 << (8 - prefix)); return result; } @@ -776,12 +708,11 @@ public class DoSFilter implements Filter } /** - * Returns the user id, used to track this connection. - * This SHOULD be overridden by subclasses. - * - * @param request the current request - * @return a unique user id, if logged in; otherwise null. + * @param request ignored + * @return null + * @deprecated User ID no longer supported */ + @Deprecated protected String extractUserId(ServletRequest request) { return null; @@ -996,26 +927,29 @@ public class DoSFilter implements Filter * Get flag to have usage rate tracked by session if a session exists. * * @return value of the flag + * @deprecated Session tracking is no longer supported */ - @ManagedAttribute("usage rate is tracked by session if one exists") + @Deprecated public boolean isTrackSessions() { - return _trackSessions; + return false; } /** * Set flag to have usage rate tracked by session if a session exists. * * @param value value of the flag + * @deprecated Session tracking is no longer supported */ + @Deprecated public void setTrackSessions(boolean value) { - _trackSessions = value; + if (value) + LOG.warn("Session Tracking is not supported"); } /** * Get flag to have usage rate tracked by IP+port (effectively connection) - * if session tracking is not used. * * @return value of the flag */ @@ -1027,7 +961,6 @@ public class DoSFilter implements Filter /** * Set flag to have usage rate tracked by IP+port (effectively connection) - * if session tracking is not used. * * @param value value of the flag */ @@ -1130,7 +1063,10 @@ public class DoSFilter implements Filter private boolean addWhitelistAddress(List list, String address) { address = address.trim(); - return address.length() > 0 && list.add(address); + if (address.length() <= 0) + return false; + list.add(address); + return true; } /** @@ -1157,7 +1093,7 @@ public class DoSFilter implements Filter * A RateTracker is associated with a connection, and stores request rate * data. */ - static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable + static class RateTracker implements Runnable, Serializable { private static final long serialVersionUID = 3534663738034577872L; @@ -1165,18 +1101,16 @@ public class DoSFilter implements Filter protected final String _filterName; protected transient ServletContext _context; protected final String _id; - protected final RateType _type; protected final int _maxRequestsPerSecond; protected final long[] _timestamps; protected int _next; - public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond) + RateTracker(ServletContext context, String filterName, String id, int maxRequestsPerSecond) { _context = context; _filterName = filterName; _id = id; - _type = type; _maxRequestsPerSecond = maxRequestsPerSecond; _timestamps = new long[maxRequestsPerSecond]; _next = 0; @@ -1212,52 +1146,6 @@ public class DoSFilter implements Filter return _id; } - public RateType getType() - { - return _type; - } - - @Override - public void valueBound(HttpSessionBindingEvent event) - { - if (LOG.isDebugEnabled()) - LOG.debug("Value bound: {}", getId()); - _context = event.getSession().getServletContext(); - } - - @Override - public void valueUnbound(HttpSessionBindingEvent event) - { - //take the tracker out of the list of trackers - DoSFilter filter = (DoSFilter)event.getSession().getServletContext().getAttribute(_filterName); - removeFromRateTrackers(filter, _id); - _context = null; - } - - @Override - public void sessionWillPassivate(HttpSessionEvent se) - { - //take the tracker of the list of trackers (if its still there) - DoSFilter filter = (DoSFilter)se.getSession().getServletContext().getAttribute(_filterName); - removeFromRateTrackers(filter, _id); - _context = null; - } - - @Override - public void sessionDidActivate(HttpSessionEvent se) - { - RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER); - ServletContext context = se.getSession().getServletContext(); - tracker.setContext(context); - DoSFilter filter = (DoSFilter)context.getAttribute(_filterName); - if (filter == null) - { - LOG.info("No filter {} for rate tracker {}", _filterName, tracker); - return; - } - addToRateTrackers(filter, tracker); - } - public void setContext(ServletContext context) { _context = context; @@ -1309,7 +1197,7 @@ public class DoSFilter implements Filter @Override public String toString() { - return "RateTracker/" + _id + "/" + _type; + return "RateTracker/" + _id; } public class Overage implements OverLimit @@ -1323,12 +1211,6 @@ public class DoSFilter implements Filter this.count = count; } - @Override - public RateType getRateType() - { - return _type; - } - @Override public String getRateId() { @@ -1350,23 +1232,20 @@ public class DoSFilter implements Filter @Override public String toString() { - final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName()); - sb.append('@').append(Integer.toHexString(hashCode())); - sb.append("[type=").append(getRateType()); - sb.append(", id=").append(getRateId()); - sb.append(", duration=").append(duration); - sb.append(", count=").append(count); - sb.append(']'); - return sb.toString(); + return OverLimit.class.getSimpleName() + '@' + Integer.toHexString(hashCode()) + + "[id=" + getRateId() + + ", duration=" + duration + + ", count=" + count + + ']'; } } } private static class FixedRateTracker extends RateTracker { - public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked) + public FixedRateTracker(ServletContext context, String filterName, String id, int numRecentRequestsTracked) { - super(context, filterName, id, type, numRecentRequestsTracked); + super(context, filterName, id, numRecentRequestsTracked); } @Override @@ -1417,17 +1296,14 @@ public class DoSFilter implements Filter private class DoSAsyncListener extends DoSTimeoutAsyncListener { - private final RateType priority; - - public DoSAsyncListener(RateType priority) + public DoSAsyncListener() { - this.priority = priority; } @Override public void onTimeout(AsyncEvent event) throws IOException { - _queues.get(priority).remove(event.getAsyncContext()); + _queue.remove(event.getAsyncContext()); // TODO what??? super.onTimeout(event); } } @@ -1475,8 +1351,6 @@ public class DoSFilter implements Filter public interface OverLimit { - RateType getRateType(); - String getRateId(); Duration getDuration(); @@ -1503,13 +1377,13 @@ public class DoSFilter implements Filter switch (action) { case REJECT: - LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); + LOG.warn("DoS ALERT: Request rejected ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal()); break; case DELAY: - LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); + LOG.warn("DoS ALERT: Request delayed={}ms, ip={}, overlimit={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getUserPrincipal()); break; case THROTTLE: - LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); + LOG.warn("DoS ALERT: Request throttled ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal()); break; } diff --git a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/AbstractDoSFilterTest.java b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/AbstractDoSFilterTest.java index 9cbe919811c..9e8f9aeecf5 100644 --- a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/AbstractDoSFilterTest.java +++ b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/AbstractDoSFilterTest.java @@ -45,9 +45,7 @@ import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.lessThan; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; public abstract class AbstractDoSFilterTest { @@ -248,60 +246,6 @@ public abstract class AbstractDoSFilterTest other.join(); } - @Test - public void testSessionTracking() throws Exception - { - // get a session, first - String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; - String response = doRequests("", 1, 0, 0, requestSession); - String sessionId = response.substring(response.indexOf("Set-Cookie: ") + 12, response.indexOf(";")); - - // all other requests use this session - String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n"; - String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n"; - String responses = doRequests(request + request + request + request + request, 2, 1100, 1100, last); - - assertEquals(11, count(responses, "HTTP/1.1 200 OK")); - assertEquals(2, count(responses, "DoSFilter: delayed")); - } - - @Test - public void testMultipleSessionTracking() throws Exception - { - // get some session ids, first - String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n"; - String closeRequest = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; - String response = doRequests(requestSession + requestSession, 1, 0, 0, closeRequest); - - String[] sessions = response.split("\r\n\r\n"); - - String sessionId1 = sessions[0].substring(sessions[0].indexOf("Set-Cookie: ") + 12, sessions[0].indexOf(";")); - String sessionId2 = sessions[1].substring(sessions[1].indexOf("Set-Cookie: ") + 12, sessions[1].indexOf(";")); - - // alternate between sessions - String request1 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n"; - String request2 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n"; - String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n"; - - // ensure the sessions are new - doRequests(request1 + request2, 1, 1100, 1100, last); - Thread.sleep(1000); - - String responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 1100, 1100, last); - - assertEquals(11, count(responses, "HTTP/1.1 200 OK")); - // This test is system speed dependent, so allow some (20%-ish) requests to be delayed, but not more. - assertThat("delayed count", count(responses, "DoSFilter: delayed"), lessThan(2)); - - // alternate between sessions - responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 250, 250, last); - - // System.err.println(responses); - assertEquals(11, count(responses, "HTTP/1.1 200 OK")); - int delayedRequests = count(responses, "DoSFilter: delayed"); - assertTrue(delayedRequests >= 2 && delayedRequests <= 5, "delayedRequests: " + delayedRequests + " is not between 2 and 5"); - } - @Test public void testUnresponsiveClient() throws Exception { diff --git a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java index dc70e1c98fc..ce881294794 100644 --- a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java +++ b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java @@ -169,7 +169,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest { boolean exceeded = false; ServletContext context = new ContextHandler.StaticContext(); - RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4); + RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 4); for (int i = 0; i < 5; i++) {