411323 - DosFilter/QoSFilter should use AsyncContext rather than Continuations.

This commit is contained in:
Simone Bordet 2014-07-23 14:07:14 +02:00
parent 103cdbf6ef
commit 5956b9e013
4 changed files with 133 additions and 81 deletions

View File

@ -31,7 +31,9 @@ import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
@ -47,9 +49,7 @@ import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import javax.servlet.http.HttpSessionEvent;
import org.eclipse.jetty.continuation.Continuation;
import org.eclipse.jetty.continuation.ContinuationListener;
import org.eclipse.jetty.continuation.ContinuationSupport;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.annotation.ManagedAttribute;
import org.eclipse.jetty.util.annotation.ManagedObject;
import org.eclipse.jetty.util.annotation.ManagedOperation;
@ -168,7 +168,10 @@ public class DoSFilter implements Filter
private static final int USER_IP = 1;
private static final int USER_UNKNOWN = 0;
private ServletContext _context;
private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED";
private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED";
private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<>();
private final List<String> _whitelist = new CopyOnWriteArrayList<>();
private volatile long _delayMs;
private volatile long _throttleMs;
private volatile long _maxWaitMs;
@ -181,34 +184,18 @@ public class DoSFilter implements Filter
private Semaphore _passes;
private volatile int _throttledRequests;
private volatile int _maxRequestsPerSec;
private Queue<Continuation>[] _queue;
private ContinuationListener[] _listeners;
private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<>();
private final List<String> _whitelist = new CopyOnWriteArrayList<>();
private Queue<AsyncContext>[] _queues;
private AsyncListener[] _listeners;
private Scheduler _scheduler;
public void init(FilterConfig filterConfig) throws ServletException
{
_context = filterConfig.getServletContext();
_queue = new Queue[getMaxPriority() + 1];
_listeners = new ContinuationListener[getMaxPriority() + 1];
for (int p = 0; p < _queue.length; p++)
_queues = new Queue[getMaxPriority() + 1];
_listeners = new AsyncListener[_queues.length];
for (int p = 0; p < _queues.length; p++)
{
_queue[p] = new ConcurrentLinkedQueue<>();
final int priority = p;
_listeners[p] = new ContinuationListener()
{
public void onComplete(Continuation continuation)
{
}
public void onTimeout(Continuation continuation)
{
_queue[priority].remove(continuation);
}
};
_queues[p] = new ConcurrentLinkedQueue<>();
_listeners[p] = new DoSAsyncListener(p);
}
_rateTrackers.clear();
@ -275,8 +262,9 @@ public class DoSFilter implements Filter
_scheduler = startScheduler();
if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
_context.setAttribute(filterConfig.getFilterName(), this);
ServletContext context = filterConfig.getServletContext();
if (context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
context.setAttribute(filterConfig.getFilterName(), this);
}
protected Scheduler startScheduler() throws ServletException
@ -306,37 +294,40 @@ public class DoSFilter implements Filter
return;
}
// Look for the rate tracker for this request
// Look for the rate tracker for this request.
RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
if (tracker == null)
{
// This is the first time we have seen this request.
if (LOG.isDebugEnabled())
LOG.debug("Filtering {}", request);
// get a rate tracker associated with this request, and record one hit
// Get a rate tracker associated with this request, and record one hit.
tracker = getRateTracker(request);
// Calculate the rate and check it is over the allowed limit
final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis());
// pass it through if we are not currently over the rate limit
// Pass it through if we are not currently over the rate limit.
if (!overRateLimit)
{
if (LOG.isDebugEnabled())
LOG.debug("Allowing {}", request);
doFilterChain(filterChain, request, response);
return;
}
// We are over the limit.
// So either reject it, delay it or throttle it
// So either reject it, delay it or throttle it.
long delayMs = getDelayMs();
boolean insertHeaders = isInsertHeaders();
switch ((int)delayMs)
{
case -1:
{
// Reject this request
LOG.warn("DOS ALERT: Request rejected ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
// Reject this request.
LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
if (insertHeaders)
response.addHeader("DoSFilter", "unavailable");
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
@ -344,39 +335,41 @@ public class DoSFilter implements Filter
}
case 0:
{
// fall through to throttle code
LOG.warn("DOS ALERT: Request throttled ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
// Fall through to throttle the request.
LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
request.setAttribute(__TRACKER, tracker);
break;
}
default:
{
// insert a delay before throttling the request
LOG.warn("DOS ALERT: Request delayed="+delayMs+"ms ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
// Insert a delay before throttling the request,
// using the suspend+timeout mechanism of AsyncContext.
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
if (insertHeaders)
response.addHeader("DoSFilter", "delayed");
Continuation continuation = ContinuationSupport.getContinuation(request);
request.setAttribute(__TRACKER, tracker);
AsyncContext asyncContext = request.startAsync();
if (delayMs > 0)
continuation.setTimeout(delayMs);
continuation.suspend();
asyncContext.setTimeout(delayMs);
asyncContext.addListener(new DoSTimeoutAsyncListener());
return;
}
}
}
// Throttle the request
if (LOG.isDebugEnabled())
LOG.debug("Throttling {}", request);
// Throttle the request.
boolean accepted = false;
try
{
// check if we can afford to accept another request at this time
// Check if we can afford to accept another request at this time.
accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
if (!accepted)
{
// we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
final Continuation continuation = ContinuationSupport.getContinuation(request);
// We were not accepted, so either we suspend to wait,
// or if we were woken up we insist or we fail.
Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
long throttleMs = getThrottleMs();
if (throttled != Boolean.TRUE && throttleMs > 0)
@ -385,30 +378,39 @@ public class DoSFilter implements Filter
request.setAttribute(__THROTTLED, Boolean.TRUE);
if (isInsertHeaders())
response.addHeader("DoSFilter", "throttled");
AsyncContext asyncContext = request.startAsync();
request.setAttribute(_suspended, Boolean.TRUE);
if (throttleMs > 0)
continuation.setTimeout(throttleMs);
continuation.suspend();
continuation.addContinuationListener(_listeners[priority]);
_queue[priority].add(continuation);
asyncContext.setTimeout(throttleMs);
asyncContext.addListener(_listeners[priority]);
_queues[priority].add(asyncContext);
if (LOG.isDebugEnabled())
LOG.debug("Throttled {}, {}ms", request, throttleMs);
return;
}
// else were we resumed?
else if (request.getAttribute("javax.servlet.resumed") == Boolean.TRUE)
Boolean resumed = (Boolean)request.getAttribute(_resumed);
if (resumed == Boolean.TRUE)
{
// we were resumed and somebody stole our pass, so we wait for the next one.
// We were resumed, we wait for the next pass.
_passes.acquire();
accepted = true;
}
}
// if we were accepted (either immediately or after throttle)
// If we were accepted (either immediately or after throttle)...
if (accepted)
// call the chain
{
// ...call the chain.
if (LOG.isDebugEnabled())
LOG.debug("Allowing {}", request);
doFilterChain(filterChain, request, response);
}
else
{
// fail the request
// ...otherwise fail the request.
if (LOG.isDebugEnabled())
LOG.debug("Rejecting {}", request);
if (isInsertHeaders())
response.addHeader("DoSFilter", "unavailable");
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
@ -416,21 +418,28 @@ public class DoSFilter implements Filter
}
catch (InterruptedException e)
{
_context.log("DoS", e);
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
}
finally
{
if (accepted)
{
// wake up the next highest priority request.
for (int p = _queue.length; p-- > 0; )
// Wake up the next highest priority request.
for (int p = _queues.length - 1; p >= 0; --p)
{
Continuation continuation = _queue[p].poll();
if (continuation != null && continuation.isSuspended())
AsyncContext asyncContext = _queues[p].poll();
if (asyncContext != null)
{
continuation.resume();
break;
ServletRequest candidate = asyncContext.getRequest();
Boolean suspended = (Boolean)candidate.getAttribute(_suspended);
if (suspended == Boolean.TRUE)
{
if (LOG.isDebugEnabled())
LOG.debug("Resuming {}", request);
candidate.setAttribute(_resumed, Boolean.TRUE);
asyncContext.dispatch();
break;
}
}
}
_passes.release();
@ -449,7 +458,6 @@ public class DoSFilter implements Filter
closeConnection(request, response, thread);
}
};
Scheduler.Task task = _scheduler.schedule(requestTimeout, getMaxRequestMs(), TimeUnit.MILLISECONDS);
try
{
@ -1056,10 +1064,10 @@ public class DoSFilter implements Filter
{
private static final long serialVersionUID = 3534663738034577872L;
transient protected final String _id;
transient protected final int _type;
transient protected final long[] _timestamps;
transient protected int _next;
protected transient final String _id;
protected transient final int _type;
protected transient final long[] _timestamps;
protected transient int _next;
public RateTracker(String id, int type, int maxRequestsPerSecond)
{
@ -1115,7 +1123,8 @@ public class DoSFilter implements Filter
//and ensure that we take ourselves out of the session so we are not saved
_rateTrackers.remove(_id);
se.getSession().removeAttribute(__TRACKER);
if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
if (LOG.isDebugEnabled())
LOG.debug("Value removed: {}", getId());
}
public void sessionDidActivate(HttpSessionEvent se)
@ -1171,4 +1180,45 @@ public class DoSFilter implements Filter
return "Fixed" + super.toString();
}
}
private class DoSTimeoutAsyncListener implements AsyncListener
{
@Override
public void onStartAsync(AsyncEvent event) throws IOException
{
}
@Override
public void onComplete(AsyncEvent event) throws IOException
{
}
@Override
public void onTimeout(AsyncEvent event) throws IOException
{
event.getAsyncContext().dispatch();
}
@Override
public void onError(AsyncEvent event) throws IOException
{
}
}
private class DoSAsyncListener extends DoSTimeoutAsyncListener
{
private final int priority;
public DoSAsyncListener(int priority)
{
this.priority = priority;
}
@Override
public void onTimeout(AsyncEvent event) throws IOException
{
_queues[priority].remove(event.getAsyncContext());
super.onTimeout(event);
}
}
}

View File

@ -362,8 +362,10 @@ public class QoSFilter implements Filter
public void onTimeout(AsyncEvent event) throws IOException
{
// Remove before it's redispatched, so it won't be
// redispatched again in the finally block below.
_queues[priority].remove(event.getAsyncContext());
// redispatched again at the end of the filtering.
AsyncContext asyncContext = event.getAsyncContext();
_queues[priority].remove(asyncContext);
asyncContext.dispatch();
}
@Override

View File

@ -18,12 +18,8 @@
package org.eclipse.jetty.servlets;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -34,6 +30,9 @@ import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class DoSFilterTest extends AbstractDoSFilterTest
{
private static final Logger LOG = Log.getLogger(DoSFilterTest.class);
@ -62,7 +61,7 @@ public class DoSFilterTest extends AbstractDoSFilterTest
}
@Test
public void isRateExceededTest() throws InterruptedException
public void testRateIsRateExceeded() throws InterruptedException
{
DoSFilter doSFilter = new DoSFilter();

View File

@ -3,3 +3,4 @@ org.eclipse.jetty.util.log.class=org.eclipse.jetty.util.log.StdErrLog
#org.eclipse.jetty.servlets.LEVEL=DEBUG
#org.eclipse.jetty.servlets.GzipFilter.LEVEL=DEBUG
#org.eclipse.jetty.servlets.QoSFilter.LEVEL=DEBUG
#org.eclipse.jetty.servlets.DoSFilter.LEVEL=DEBUG