Simplify the DosFilter for #1256 (#10748)

Use only IP tracking for the DosFilter to fix #1256

Signed-off-by: gregw <gregw@webtide.com>
This commit is contained in:
Greg Wilkins 2023-10-18 22:34:54 +02:00 committed by GitHub
parent 8c94490e18
commit 168d8715d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 248 deletions

View File

@ -32,9 +32,7 @@ The filter works on the assumption that the attacker might be written in simple
[[dos-filter-using]] [[dos-filter-using]]
==== Using the DoS Filter ==== Using the DoS Filter
Jetty places throttled requests in a priority queue, giving priority first to authenticated users and users with an HttpSession, then to connections identified by their IP addresses. Jetty places throttled requests in a queue, and proceed only when there is capacity available.
Connections with no way to identify them have lowest priority.
To uniquely identify authenticated users, you should implement the The extractUserId(ServletRequest request) function.
===== Required JARs ===== Required JARs
@ -94,11 +92,8 @@ Default is 30000L.
insertHeaders:: insertHeaders::
If true, insert the DoSFilter headers into the response. If true, insert the DoSFilter headers into the response.
Defaults to true. Defaults to true.
trackSessions::
If true, usage rate is tracked by session if a session exists.
Defaults to true.
remotePort:: remotePort::
If true and session tracking is not used, then rate is tracked by IP and port (effectively connection). If true, then rate is tracked by IP and port (effectively connection).
Defaults to false. Defaults to false.
ipWhitelist:: ipWhitelist::
A comma-separated list of IP addresses that will not be rate limited. A comma-separated list of IP addresses that will not be rate limited.

View File

@ -17,10 +17,8 @@ import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -43,11 +41,6 @@ import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionActivationListener;
import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import javax.servlet.http.HttpSessionEvent;
import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.handler.ContextHandler; import org.eclipse.jetty.server.handler.ContextHandler;
@ -74,10 +67,8 @@ import org.slf4j.LoggerFactory;
* second. If a limit is exceeded, the request is either rejected, delayed, or * second. If a limit is exceeded, the request is either rejected, delayed, or
* throttled. * throttled.
* <p> * <p>
* When a request is throttled, it is placed in a priority queue. Priority is * When a request is throttled, it is placed in a queue and will only proceed
* given first to authenticated users and users with an HttpSession, then * when there is capacity.
* connections which can be identified by their IP addresses. Connections with
* no way to identify them are given lowest priority.
* <p> * <p>
* The {@link #extractUserId(ServletRequest request)} function should be * The {@link #extractUserId(ServletRequest request)} function should be
* implemented, in order to uniquely identify authenticated users. * implemented, in order to uniquely identify authenticated users.
@ -106,10 +97,8 @@ import org.slf4j.LoggerFactory;
* before deciding that the user has gone away, and discarding it</dd> * before deciding that the user has gone away, and discarding it</dd>
* <dt>insertHeaders</dt> * <dt>insertHeaders</dt>
* <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd> * <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd>
* <dt>trackSessions</dt>
* <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
* <dt>remotePort</dt> * <dt>remotePort</dt>
* <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd> * <dd>if true then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
* <dt>ipWhitelist</dt> * <dt>ipWhitelist</dt>
* <dd>a comma-separated list of IP addresses that will not be rate limited</dd> * <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
* <dt>managedAttr</dt> * <dt>managedAttr</dt>
@ -156,12 +145,14 @@ public class DoSFilter implements Filter
static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs"; static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs"; static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders"; static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
@Deprecated
static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions"; static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
static final String REMOTE_PORT_INIT_PARAM = "remotePort"; static final String REMOTE_PORT_INIT_PARAM = "remotePort";
static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist"; static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
static final String ENABLED_INIT_PARAM = "enabled"; static final String ENABLED_INIT_PARAM = "enabled";
static final String TOO_MANY_CODE = "tooManyCode"; static final String TOO_MANY_CODE = "tooManyCode";
@Deprecated
public enum RateType public enum RateType
{ {
AUTH, AUTH,
@ -181,7 +172,6 @@ public class DoSFilter implements Filter
private volatile long _maxRequestMs; private volatile long _maxRequestMs;
private volatile long _maxIdleTrackerMs; private volatile long _maxIdleTrackerMs;
private volatile boolean _insertHeaders; private volatile boolean _insertHeaders;
private volatile boolean _trackSessions;
private volatile boolean _remotePort; private volatile boolean _remotePort;
private volatile boolean _enabled; private volatile boolean _enabled;
private volatile String _name; private volatile String _name;
@ -189,20 +179,14 @@ public class DoSFilter implements Filter
private Semaphore _passes; private Semaphore _passes;
private volatile int _throttledRequests; private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec; private volatile int _maxRequestsPerSec;
private Map<RateType, Queue<AsyncContext>> _queues = new HashMap<>(); private final Queue<AsyncContext> _queue = new ConcurrentLinkedQueue<>();
private Map<RateType, AsyncListener> _listeners = new HashMap<>(); private final AsyncListener _asyncListener = new DoSAsyncListener();
private Scheduler _scheduler; private Scheduler _scheduler;
private ServletContext _context; private ServletContext _context;
@Override @Override
public void init(FilterConfig filterConfig) throws ServletException public void init(FilterConfig filterConfig) throws ServletException
{ {
for (RateType rateType : RateType.values())
{
_queues.put(rateType, new ConcurrentLinkedQueue<>());
_listeners.put(rateType, new DoSAsyncListener(rateType));
}
_rateTrackers.clear(); _rateTrackers.clear();
int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC; int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
@ -395,15 +379,14 @@ public class DoSFilter implements Filter
long throttleMs = getThrottleMs(); long throttleMs = getThrottleMs();
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0) if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
{ {
RateType priority = getPriority(request, tracker);
request.setAttribute(__THROTTLED, Boolean.TRUE); request.setAttribute(__THROTTLED, Boolean.TRUE);
if (isInsertHeaders()) if (isInsertHeaders())
response.addHeader("DoSFilter", "throttled"); response.addHeader("DoSFilter", "throttled");
AsyncContext asyncContext = request.startAsync(); AsyncContext asyncContext = request.startAsync();
request.setAttribute(_suspended, Boolean.TRUE); request.setAttribute(_suspended, Boolean.TRUE);
asyncContext.setTimeout(throttleMs); asyncContext.setTimeout(throttleMs);
asyncContext.addListener(_listeners.get(priority)); asyncContext.addListener(_asyncListener);
_queues.get(priority).add(asyncContext); _queue.add(asyncContext);
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("Throttled {}, {}ms", request, throttleMs); LOG.debug("Throttled {}, {}ms", request, throttleMs);
return; return;
@ -447,10 +430,7 @@ public class DoSFilter implements Filter
{ {
try try
{ {
// Wake up the next highest priority request. AsyncContext asyncContext = _queue.poll();
for (RateType rateType : RateType.values())
{
AsyncContext asyncContext = _queues.get(rateType).poll();
if (asyncContext != null) if (asyncContext != null)
{ {
ServletRequest candidate = asyncContext.getRequest(); ServletRequest candidate = asyncContext.getRequest();
@ -461,8 +441,6 @@ public class DoSFilter implements Filter
LOG.debug("Resuming {}", request); LOG.debug("Resuming {}", request);
candidate.setAttribute(_resumed, Boolean.TRUE); candidate.setAttribute(_resumed, Boolean.TRUE);
asyncContext.dispatch(); asyncContext.dispatch();
break;
}
} }
} }
} }
@ -524,27 +502,13 @@ public class DoSFilter implements Filter
} }
/** /**
* Get priority for this request, based on user type * @return null
* * @deprecated Priority no longer supported
* @param request the current request
* @param tracker the rate tracker for this request
* @return the priority for this request
*/
private RateType getPriority(HttpServletRequest request, RateTracker tracker)
{
if (extractUserId(request) != null)
return RateType.AUTH;
if (tracker != null)
return tracker.getType();
return RateType.UNKNOWN;
}
/**
* @return the maximum priority that we can assign to a request
*/ */
@Deprecated
protected RateType getMaxPriority() protected RateType getMaxPriority()
{ {
return RateType.AUTH; return null;
} }
public void setListener(DoSFilter.Listener listener) public void setListener(DoSFilter.Listener listener)
@ -570,62 +534,30 @@ public class DoSFilter implements Filter
* <p> * <p>
* Assumes that each connection has an identifying characteristic, and goes * Assumes that each connection has an identifying characteristic, and goes
* through them in order, taking the first that matches: user id (logged * through them in order, taking the first that matches: user id (logged
* in), session id, client IP address. Unidentifiable connections are lumped * in), client IP address. Unidentifiable connections are lumped
* into one. * into one.
* <p>
* When a session expires, its rate tracker is automatically deleted.
* *
* @param request the current request * @param request the current request
* @return the request rate tracker for the current connection * @return the request rate tracker for the current connection
*/ */
RateTracker getRateTracker(ServletRequest request) RateTracker getRateTracker(ServletRequest request)
{ {
HttpSession session = ((HttpServletRequest)request).getSession(false); String loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
String loadId = extractUserId(request);
final RateType type;
if (loadId != null)
{
type = RateType.AUTH;
}
else
{
if (isTrackSessions() && session != null && !session.isNew())
{
loadId = session.getId();
type = RateType.SESSION;
}
else
{
loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
type = RateType.IP;
}
}
RateTracker tracker = _rateTrackers.get(loadId); RateTracker tracker = _rateTrackers.get(loadId);
if (tracker == null) if (tracker == null)
{ {
boolean allowed = checkWhitelist(request.getRemoteAddr()); boolean allowed = checkWhitelist(request.getRemoteAddr());
int maxRequestsPerSec = getMaxRequestsPerSec(); int maxRequestsPerSec = getMaxRequestsPerSec();
tracker = allowed ? new FixedRateTracker(_context, _name, loadId, type, maxRequestsPerSec) tracker = allowed ? new FixedRateTracker(_context, _name, loadId, maxRequestsPerSec)
: new RateTracker(_context, _name, loadId, type, maxRequestsPerSec); : new RateTracker(_context, _name, loadId, maxRequestsPerSec);
tracker.setContext(_context); tracker.setContext(_context);
RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker); RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
if (existing != null) if (existing != null)
tracker = existing; tracker = existing;
if (type == RateType.IP)
{
// USER_IP expiration from _rateTrackers is handled by the _scheduler
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS); _scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
} }
else if (session != null)
{
// USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
session.setAttribute(__TRACKER, tracker);
}
}
return tracker; return tracker;
} }
@ -750,7 +682,7 @@ public class DoSFilter implements Filter
return result; return result;
// Sets the _prefix_ most significant bits to 1 // Sets the _prefix_ most significant bits to 1
result[index] = (byte)~((1 << (8 - prefix)) - 1); result[index] = (byte)-(1 << (8 - prefix));
return result; return result;
} }
@ -776,12 +708,11 @@ public class DoSFilter implements Filter
} }
/** /**
* Returns the user id, used to track this connection. * @param request ignored
* This SHOULD be overridden by subclasses. * @return null
* * @deprecated User ID no longer supported
* @param request the current request
* @return a unique user id, if logged in; otherwise null.
*/ */
@Deprecated
protected String extractUserId(ServletRequest request) protected String extractUserId(ServletRequest request)
{ {
return null; return null;
@ -996,26 +927,29 @@ public class DoSFilter implements Filter
* Get flag to have usage rate tracked by session if a session exists. * Get flag to have usage rate tracked by session if a session exists.
* *
* @return value of the flag * @return value of the flag
* @deprecated Session tracking is no longer supported
*/ */
@ManagedAttribute("usage rate is tracked by session if one exists") @Deprecated
public boolean isTrackSessions() public boolean isTrackSessions()
{ {
return _trackSessions; return false;
} }
/** /**
* Set flag to have usage rate tracked by session if a session exists. * Set flag to have usage rate tracked by session if a session exists.
* *
* @param value value of the flag * @param value value of the flag
* @deprecated Session tracking is no longer supported
*/ */
@Deprecated
public void setTrackSessions(boolean value) public void setTrackSessions(boolean value)
{ {
_trackSessions = value; if (value)
LOG.warn("Session Tracking is not supported");
} }
/** /**
* Get flag to have usage rate tracked by IP+port (effectively connection) * Get flag to have usage rate tracked by IP+port (effectively connection)
* if session tracking is not used.
* *
* @return value of the flag * @return value of the flag
*/ */
@ -1027,7 +961,6 @@ public class DoSFilter implements Filter
/** /**
* Set flag to have usage rate tracked by IP+port (effectively connection) * Set flag to have usage rate tracked by IP+port (effectively connection)
* if session tracking is not used.
* *
* @param value value of the flag * @param value value of the flag
*/ */
@ -1130,7 +1063,10 @@ public class DoSFilter implements Filter
private boolean addWhitelistAddress(List<String> list, String address) private boolean addWhitelistAddress(List<String> list, String address)
{ {
address = address.trim(); address = address.trim();
return address.length() > 0 && list.add(address); if (address.length() <= 0)
return false;
list.add(address);
return true;
} }
/** /**
@ -1157,7 +1093,7 @@ public class DoSFilter implements Filter
* A RateTracker is associated with a connection, and stores request rate * A RateTracker is associated with a connection, and stores request rate
* data. * data.
*/ */
static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSessionActivationListener, Serializable static class RateTracker implements Runnable, Serializable
{ {
private static final long serialVersionUID = 3534663738034577872L; private static final long serialVersionUID = 3534663738034577872L;
@ -1165,18 +1101,16 @@ public class DoSFilter implements Filter
protected final String _filterName; protected final String _filterName;
protected transient ServletContext _context; protected transient ServletContext _context;
protected final String _id; protected final String _id;
protected final RateType _type;
protected final int _maxRequestsPerSecond; protected final int _maxRequestsPerSecond;
protected final long[] _timestamps; protected final long[] _timestamps;
protected int _next; protected int _next;
public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond) RateTracker(ServletContext context, String filterName, String id, int maxRequestsPerSecond)
{ {
_context = context; _context = context;
_filterName = filterName; _filterName = filterName;
_id = id; _id = id;
_type = type;
_maxRequestsPerSecond = maxRequestsPerSecond; _maxRequestsPerSecond = maxRequestsPerSecond;
_timestamps = new long[maxRequestsPerSecond]; _timestamps = new long[maxRequestsPerSecond];
_next = 0; _next = 0;
@ -1212,52 +1146,6 @@ public class DoSFilter implements Filter
return _id; return _id;
} }
public RateType getType()
{
return _type;
}
@Override
public void valueBound(HttpSessionBindingEvent event)
{
if (LOG.isDebugEnabled())
LOG.debug("Value bound: {}", getId());
_context = event.getSession().getServletContext();
}
@Override
public void valueUnbound(HttpSessionBindingEvent event)
{
//take the tracker out of the list of trackers
DoSFilter filter = (DoSFilter)event.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
}
@Override
public void sessionWillPassivate(HttpSessionEvent se)
{
//take the tracker of the list of trackers (if its still there)
DoSFilter filter = (DoSFilter)se.getSession().getServletContext().getAttribute(_filterName);
removeFromRateTrackers(filter, _id);
_context = null;
}
@Override
public void sessionDidActivate(HttpSessionEvent se)
{
RateTracker tracker = (RateTracker)se.getSession().getAttribute(__TRACKER);
ServletContext context = se.getSession().getServletContext();
tracker.setContext(context);
DoSFilter filter = (DoSFilter)context.getAttribute(_filterName);
if (filter == null)
{
LOG.info("No filter {} for rate tracker {}", _filterName, tracker);
return;
}
addToRateTrackers(filter, tracker);
}
public void setContext(ServletContext context) public void setContext(ServletContext context)
{ {
_context = context; _context = context;
@ -1309,7 +1197,7 @@ public class DoSFilter implements Filter
@Override @Override
public String toString() public String toString()
{ {
return "RateTracker/" + _id + "/" + _type; return "RateTracker/" + _id;
} }
public class Overage implements OverLimit public class Overage implements OverLimit
@ -1323,12 +1211,6 @@ public class DoSFilter implements Filter
this.count = count; this.count = count;
} }
@Override
public RateType getRateType()
{
return _type;
}
@Override @Override
public String getRateId() public String getRateId()
{ {
@ -1350,23 +1232,20 @@ public class DoSFilter implements Filter
@Override @Override
public String toString() public String toString()
{ {
final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName()); return OverLimit.class.getSimpleName() + '@' + Integer.toHexString(hashCode()) +
sb.append('@').append(Integer.toHexString(hashCode())); "[id=" + getRateId() +
sb.append("[type=").append(getRateType()); ", duration=" + duration +
sb.append(", id=").append(getRateId()); ", count=" + count +
sb.append(", duration=").append(duration); ']';
sb.append(", count=").append(count);
sb.append(']');
return sb.toString();
} }
} }
} }
private static class FixedRateTracker extends RateTracker private static class FixedRateTracker extends RateTracker
{ {
public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked) public FixedRateTracker(ServletContext context, String filterName, String id, int numRecentRequestsTracked)
{ {
super(context, filterName, id, type, numRecentRequestsTracked); super(context, filterName, id, numRecentRequestsTracked);
} }
@Override @Override
@ -1417,17 +1296,14 @@ public class DoSFilter implements Filter
private class DoSAsyncListener extends DoSTimeoutAsyncListener private class DoSAsyncListener extends DoSTimeoutAsyncListener
{ {
private final RateType priority; public DoSAsyncListener()
public DoSAsyncListener(RateType priority)
{ {
this.priority = priority;
} }
@Override @Override
public void onTimeout(AsyncEvent event) throws IOException public void onTimeout(AsyncEvent event) throws IOException
{ {
_queues.get(priority).remove(event.getAsyncContext()); _queue.remove(event.getAsyncContext()); // TODO what???
super.onTimeout(event); super.onTimeout(event);
} }
} }
@ -1475,8 +1351,6 @@ public class DoSFilter implements Filter
public interface OverLimit public interface OverLimit
{ {
RateType getRateType();
String getRateId(); String getRateId();
Duration getDuration(); Duration getDuration();
@ -1503,13 +1377,13 @@ public class DoSFilter implements Filter
switch (action) switch (action)
{ {
case REJECT: case REJECT:
LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); LOG.warn("DoS ALERT: Request rejected ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal());
break; break;
case DELAY: case DELAY:
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); LOG.warn("DoS ALERT: Request delayed={}ms, ip={}, overlimit={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getUserPrincipal());
break; break;
case THROTTLE: case THROTTLE:
LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); LOG.warn("DoS ALERT: Request throttled ip={}, overlimit={}, user={}", request.getRemoteAddr(), overlimit, request.getUserPrincipal());
break; break;
} }

View File

@ -45,9 +45,7 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public abstract class AbstractDoSFilterTest public abstract class AbstractDoSFilterTest
{ {
@ -248,60 +246,6 @@ public abstract class AbstractDoSFilterTest
other.join(); other.join();
} }
@Test
public void testSessionTracking() throws Exception
{
// get a session, first
String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String response = doRequests("", 1, 0, 0, requestSession);
String sessionId = response.substring(response.indexOf("Set-Cookie: ") + 12, response.indexOf(";"));
// all other requests use this session
String request = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n";
String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n";
String responses = doRequests(request + request + request + request + request, 2, 1100, 1100, last);
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
assertEquals(2, count(responses, "DoSFilter: delayed"));
}
@Test
public void testMultipleSessionTracking() throws Exception
{
// get some session ids, first
String requestSession = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n";
String closeRequest = "GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String response = doRequests(requestSession + requestSession, 1, 0, 0, closeRequest);
String[] sessions = response.split("\r\n\r\n");
String sessionId1 = sessions[0].substring(sessions[0].indexOf("Set-Cookie: ") + 12, sessions[0].indexOf(";"));
String sessionId2 = sessions[1].substring(sessions[1].indexOf("Set-Cookie: ") + 12, sessions[1].indexOf(";"));
// alternate between sessions
String request1 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n";
String request2 = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n";
String last = "GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n";
// ensure the sessions are new
doRequests(request1 + request2, 1, 1100, 1100, last);
Thread.sleep(1000);
String responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 1100, 1100, last);
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
// This test is system speed dependent, so allow some (20%-ish) requests to be delayed, but not more.
assertThat("delayed count", count(responses, "DoSFilter: delayed"), lessThan(2));
// alternate between sessions
responses = doRequests(request1 + request2 + request1 + request2 + request1, 2, 250, 250, last);
// System.err.println(responses);
assertEquals(11, count(responses, "HTTP/1.1 200 OK"));
int delayedRequests = count(responses, "DoSFilter: delayed");
assertTrue(delayedRequests >= 2 && delayedRequests <= 5, "delayedRequests: " + delayedRequests + " is not between 2 and 5");
}
@Test @Test
public void testUnresponsiveClient() throws Exception public void testUnresponsiveClient() throws Exception
{ {

View File

@ -169,7 +169,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
{ {
boolean exceeded = false; boolean exceeded = false;
ServletContext context = new ContextHandler.StaticContext(); ServletContext context = new ContextHandler.StaticContext();
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4); RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 4);
for (int i = 0; i < 5; i++) for (int i = 0; i < 5; i++)
{ {