Merge remote-tracking branch 'origin/jetty-9.4.x' into jetty-10.0.x
This commit is contained in:
commit
4a0ffc2f73
|
@ -20,9 +20,13 @@ package org.eclipse.jetty.servlets;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Queue;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
|
@ -162,10 +166,13 @@ public class DoSFilter implements Filter
|
|||
static final String ENABLED_INIT_PARAM = "enabled";
|
||||
static final String TOO_MANY_CODE = "tooManyCode";
|
||||
|
||||
private static final int USER_AUTH = 2;
|
||||
private static final int USER_SESSION = 2;
|
||||
private static final int USER_IP = 1;
|
||||
private static final int USER_UNKNOWN = 0;
|
||||
public enum RateType
|
||||
{
|
||||
AUTH,
|
||||
SESSION,
|
||||
IP,
|
||||
UNKNOWN
|
||||
}
|
||||
|
||||
private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED";
|
||||
private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED";
|
||||
|
@ -182,23 +189,22 @@ public class DoSFilter implements Filter
|
|||
private volatile boolean _remotePort;
|
||||
private volatile boolean _enabled;
|
||||
private volatile String _name;
|
||||
private DoSFilter.Listener _listener = new Listener();
|
||||
private Semaphore _passes;
|
||||
private volatile int _throttledRequests;
|
||||
private volatile int _maxRequestsPerSec;
|
||||
private Queue<AsyncContext>[] _queues;
|
||||
private AsyncListener[] _listeners;
|
||||
private Map<RateType, Queue<AsyncContext>> _queues = new HashMap<>();
|
||||
private Map<RateType, AsyncListener> _listeners = new HashMap<>();
|
||||
private Scheduler _scheduler;
|
||||
private ServletContext _context;
|
||||
|
||||
@Override
|
||||
public void init(FilterConfig filterConfig) throws ServletException
|
||||
{
|
||||
_queues = new Queue[getMaxPriority() + 1];
|
||||
_listeners = new AsyncListener[_queues.length];
|
||||
for (int p = 0; p < _queues.length; p++)
|
||||
for (RateType rateType : RateType.values())
|
||||
{
|
||||
_queues[p] = new ConcurrentLinkedQueue<>();
|
||||
_listeners[p] = new DoSAsyncListener(p);
|
||||
_queues.put(rateType, new ConcurrentLinkedQueue<>());
|
||||
_listeners.put(rateType, new DoSAsyncListener(rateType));
|
||||
}
|
||||
|
||||
_rateTrackers.clear();
|
||||
|
@ -306,67 +312,76 @@ public class DoSFilter implements Filter
|
|||
|
||||
// Look for the rate tracker for this request.
|
||||
RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
|
||||
if (tracker == null)
|
||||
if (tracker != null)
|
||||
{
|
||||
// This is the first time we have seen this request.
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Filtering {}", request);
|
||||
|
||||
// Get a rate tracker associated with this request, and record one hit.
|
||||
tracker = getRateTracker(request);
|
||||
|
||||
// Calculate the rate and check if it is over the allowed limit
|
||||
final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis());
|
||||
|
||||
// Pass it through if we are not currently over the rate limit.
|
||||
if (!overRateLimit)
|
||||
{
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Allowing {}", request);
|
||||
doFilterChain(filterChain, request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
// We are over the limit.
|
||||
|
||||
// So either reject it, delay it or throttle it.
|
||||
long delayMs = getDelayMs();
|
||||
boolean insertHeaders = isInsertHeaders();
|
||||
switch ((int)delayMs)
|
||||
{
|
||||
case -1:
|
||||
{
|
||||
// Reject this request.
|
||||
LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
|
||||
if (insertHeaders)
|
||||
response.addHeader("DoSFilter", "unavailable");
|
||||
response.sendError(getTooManyCode());
|
||||
return;
|
||||
}
|
||||
case 0:
|
||||
{
|
||||
// Fall through to throttle the request.
|
||||
LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
|
||||
request.setAttribute(__TRACKER, tracker);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
// Insert a delay before throttling the request,
|
||||
// using the suspend+timeout mechanism of AsyncContext.
|
||||
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal());
|
||||
if (insertHeaders)
|
||||
response.addHeader("DoSFilter", "delayed");
|
||||
request.setAttribute(__TRACKER, tracker);
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
if (delayMs > 0)
|
||||
asyncContext.setTimeout(delayMs);
|
||||
asyncContext.addListener(new DoSTimeoutAsyncListener());
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Redispatched, RateTracker present in request attributes.
|
||||
throttleRequest(request, response, filterChain, tracker);
|
||||
return;
|
||||
}
|
||||
|
||||
// This is the first time we have seen this request.
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Filtering {}", request);
|
||||
|
||||
// Get a rate tracker associated with this request, and record one hit.
|
||||
tracker = getRateTracker(request);
|
||||
|
||||
// Calculate the rate and check if it is over the allowed limit
|
||||
final OverLimit overLimit = tracker.isRateExceeded(System.currentTimeMillis());
|
||||
|
||||
// Pass it through if we are not currently over the rate limit.
|
||||
if (overLimit == null)
|
||||
{
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Allowing {}", request);
|
||||
doFilterChain(filterChain, request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
// We are over the limit.
|
||||
|
||||
// Ask listener what to perform.
|
||||
Action action = _listener.onRequestOverLimit(request, overLimit, this);
|
||||
|
||||
// Perform action
|
||||
long delayMs = getDelayMs();
|
||||
boolean insertHeaders = isInsertHeaders();
|
||||
switch (action)
|
||||
{
|
||||
case NO_ACTION:
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Allowing over-limit request {}", request);
|
||||
doFilterChain(filterChain, request, response);
|
||||
break;
|
||||
case ABORT:
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Aborting over-limit request {}", request);
|
||||
response.sendError(-1);
|
||||
return;
|
||||
case REJECT:
|
||||
if (insertHeaders)
|
||||
response.addHeader("DoSFilter", "unavailable");
|
||||
response.sendError(getTooManyCode());
|
||||
return;
|
||||
case DELAY:
|
||||
// Insert a delay before throttling the request,
|
||||
// using the suspend+timeout mechanism of AsyncContext.
|
||||
if (insertHeaders)
|
||||
response.addHeader("DoSFilter", "delayed");
|
||||
request.setAttribute(__TRACKER, tracker);
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
if (delayMs > 0)
|
||||
asyncContext.setTimeout(delayMs);
|
||||
asyncContext.addListener(new DoSTimeoutAsyncListener());
|
||||
break;
|
||||
case THROTTLE:
|
||||
throttleRequest(request, response, filterChain, tracker);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private void throttleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain, RateTracker tracker) throws IOException, ServletException
|
||||
{
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Throttling {}", request);
|
||||
|
||||
|
@ -384,15 +399,15 @@ public class DoSFilter implements Filter
|
|||
long throttleMs = getThrottleMs();
|
||||
if (!Boolean.TRUE.equals(throttled) && throttleMs > 0)
|
||||
{
|
||||
int priority = getPriority(request, tracker);
|
||||
RateType priority = getPriority(request, tracker);
|
||||
request.setAttribute(__THROTTLED, Boolean.TRUE);
|
||||
if (isInsertHeaders())
|
||||
response.addHeader("DoSFilter", "throttled");
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
request.setAttribute(_suspended, Boolean.TRUE);
|
||||
asyncContext.setTimeout(throttleMs);
|
||||
asyncContext.addListener(_listeners[priority]);
|
||||
_queues[priority].add(asyncContext);
|
||||
asyncContext.addListener(_listeners.get(priority));
|
||||
_queues.get(priority).add(asyncContext);
|
||||
if (LOG.isDebugEnabled())
|
||||
LOG.debug("Throttled {}, {}ms", request, throttleMs);
|
||||
return;
|
||||
|
@ -437,9 +452,9 @@ public class DoSFilter implements Filter
|
|||
try
|
||||
{
|
||||
// Wake up the next highest priority request.
|
||||
for (int p = _queues.length - 1; p >= 0; --p)
|
||||
for (RateType rateType : RateType.values())
|
||||
{
|
||||
AsyncContext asyncContext = _queues[p].poll();
|
||||
AsyncContext asyncContext = _queues.get(rateType).poll();
|
||||
if (asyncContext != null)
|
||||
{
|
||||
ServletRequest candidate = asyncContext.getRequest();
|
||||
|
@ -519,21 +534,31 @@ public class DoSFilter implements Filter
|
|||
* @param tracker the rate tracker for this request
|
||||
* @return the priority for this request
|
||||
*/
|
||||
private int getPriority(HttpServletRequest request, RateTracker tracker)
|
||||
private RateType getPriority(HttpServletRequest request, RateTracker tracker)
|
||||
{
|
||||
if (extractUserId(request) != null)
|
||||
return USER_AUTH;
|
||||
return RateType.AUTH;
|
||||
if (tracker != null)
|
||||
return tracker.getType();
|
||||
return USER_UNKNOWN;
|
||||
return RateType.UNKNOWN;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the maximum priority that we can assign to a request
|
||||
*/
|
||||
protected int getMaxPriority()
|
||||
protected RateType getMaxPriority()
|
||||
{
|
||||
return USER_AUTH;
|
||||
return RateType.AUTH;
|
||||
}
|
||||
|
||||
public void setListener(DoSFilter.Listener listener)
|
||||
{
|
||||
_listener = Objects.requireNonNull(listener, "Listener may not be null");
|
||||
}
|
||||
|
||||
public DoSFilter.Listener getListener()
|
||||
{
|
||||
return _listener;
|
||||
}
|
||||
|
||||
private void schedule(RateTracker tracker)
|
||||
|
@ -562,22 +587,22 @@ public class DoSFilter implements Filter
|
|||
HttpSession session = ((HttpServletRequest)request).getSession(false);
|
||||
|
||||
String loadId = extractUserId(request);
|
||||
final int type;
|
||||
final RateType type;
|
||||
if (loadId != null)
|
||||
{
|
||||
type = USER_AUTH;
|
||||
type = RateType.AUTH;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (isTrackSessions() && session != null && !session.isNew())
|
||||
{
|
||||
loadId = session.getId();
|
||||
type = USER_SESSION;
|
||||
type = RateType.SESSION;
|
||||
}
|
||||
else
|
||||
{
|
||||
loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr();
|
||||
type = USER_IP;
|
||||
type = RateType.IP;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -594,7 +619,7 @@ public class DoSFilter implements Filter
|
|||
if (existing != null)
|
||||
tracker = existing;
|
||||
|
||||
if (type == USER_IP)
|
||||
if (type == RateType.IP)
|
||||
{
|
||||
// USER_IP expiration from _rateTrackers is handled by the _scheduler
|
||||
_scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS);
|
||||
|
@ -1032,6 +1057,11 @@ public class DoSFilter implements Filter
|
|||
_enabled = enabled;
|
||||
}
|
||||
|
||||
/**
|
||||
* Status code for Rejected for too many requests.
|
||||
*
|
||||
* @return the configured status code (default: 429 - Too Many Requests)
|
||||
*/
|
||||
public int getTooManyCode()
|
||||
{
|
||||
return _tooManyCode;
|
||||
|
@ -1120,6 +1150,13 @@ public class DoSFilter implements Filter
|
|||
return _whitelist.remove(address);
|
||||
}
|
||||
|
||||
private String createRemotePortId(ServletRequest request)
|
||||
{
|
||||
String addr = request.getRemoteAddr();
|
||||
int port = request.getRemotePort();
|
||||
return addr + ":" + port;
|
||||
}
|
||||
|
||||
/**
|
||||
* A RateTracker is associated with a connection, and stores request rate
|
||||
* data.
|
||||
|
@ -1132,17 +1169,19 @@ public class DoSFilter implements Filter
|
|||
protected final String _filterName;
|
||||
protected transient ServletContext _context;
|
||||
protected final String _id;
|
||||
protected final int _type;
|
||||
protected final RateType _type;
|
||||
protected final int _maxRequestsPerSecond;
|
||||
protected final long[] _timestamps;
|
||||
|
||||
protected int _next;
|
||||
|
||||
public RateTracker(ServletContext context, String filterName, String id, int type, int maxRequestsPerSecond)
|
||||
public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond)
|
||||
{
|
||||
_context = context;
|
||||
_filterName = filterName;
|
||||
_id = id;
|
||||
_type = type;
|
||||
_maxRequestsPerSecond = maxRequestsPerSecond;
|
||||
_timestamps = new long[maxRequestsPerSecond];
|
||||
_next = 0;
|
||||
}
|
||||
|
@ -1151,7 +1190,7 @@ public class DoSFilter implements Filter
|
|||
* @param now the time now (in milliseconds)
|
||||
* @return the current calculated request rate over the last second
|
||||
*/
|
||||
public boolean isRateExceeded(long now)
|
||||
public OverLimit isRateExceeded(long now)
|
||||
{
|
||||
final long last;
|
||||
try (AutoLock l = _lock.lock())
|
||||
|
@ -1161,7 +1200,17 @@ public class DoSFilter implements Filter
|
|||
_next = (_next + 1) % _timestamps.length;
|
||||
}
|
||||
|
||||
return last != 0 && (now - last) < 1000L;
|
||||
if (last == 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
long rate = (now - last);
|
||||
if (rate < 1000L)
|
||||
{
|
||||
return new Overage(Duration.ofMillis(rate), _maxRequestsPerSecond);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public String getId()
|
||||
|
@ -1169,7 +1218,7 @@ public class DoSFilter implements Filter
|
|||
return _id;
|
||||
}
|
||||
|
||||
public int getType()
|
||||
public RateType getType()
|
||||
{
|
||||
return _type;
|
||||
}
|
||||
|
@ -1242,7 +1291,7 @@ public class DoSFilter implements Filter
|
|||
{
|
||||
if (_context == null)
|
||||
{
|
||||
LOG.warn("Unknkown context for rate tracker {}", this);
|
||||
LOG.warn("Unknown context for rate tracker {}", this);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1268,17 +1317,66 @@ public class DoSFilter implements Filter
|
|||
{
|
||||
return "RateTracker/" + _id + "/" + _type;
|
||||
}
|
||||
|
||||
public class Overage implements OverLimit
|
||||
{
|
||||
private final Duration duration;
|
||||
private final long count;
|
||||
|
||||
public Overage(Duration dur, long count)
|
||||
{
|
||||
this.duration = dur;
|
||||
this.count = count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateType getRateType()
|
||||
{
|
||||
return _type;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getRateId()
|
||||
{
|
||||
return _id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Duration getDuration()
|
||||
{
|
||||
return duration;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getCount()
|
||||
{
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString()
|
||||
{
|
||||
final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName());
|
||||
sb.append('@').append(Integer.toHexString(hashCode()));
|
||||
sb.append("[type=").append(getRateType());
|
||||
sb.append(", id=").append(getRateId());
|
||||
sb.append(", duration=").append(duration);
|
||||
sb.append(", count=").append(count);
|
||||
sb.append(']');
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class FixedRateTracker extends RateTracker
|
||||
{
|
||||
public FixedRateTracker(ServletContext context, String filterName, String id, int type, int numRecentRequestsTracked)
|
||||
public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked)
|
||||
{
|
||||
super(context, filterName, id, type, numRecentRequestsTracked);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRateExceeded(long now)
|
||||
public OverLimit isRateExceeded(long now)
|
||||
{
|
||||
// rate limit is never exceeded, but we keep track of the request timestamps
|
||||
// so that we know whether there was recent activity on this tracker
|
||||
|
@ -1289,7 +1387,7 @@ public class DoSFilter implements Filter
|
|||
_next = (_next + 1) % _timestamps.length;
|
||||
}
|
||||
|
||||
return false;
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1325,9 +1423,9 @@ public class DoSFilter implements Filter
|
|||
|
||||
private class DoSAsyncListener extends DoSTimeoutAsyncListener
|
||||
{
|
||||
private final int priority;
|
||||
private final RateType priority;
|
||||
|
||||
public DoSAsyncListener(int priority)
|
||||
public DoSAsyncListener(RateType priority)
|
||||
{
|
||||
this.priority = priority;
|
||||
}
|
||||
|
@ -1335,15 +1433,93 @@ public class DoSFilter implements Filter
|
|||
@Override
|
||||
public void onTimeout(AsyncEvent event) throws IOException
|
||||
{
|
||||
_queues[priority].remove(event.getAsyncContext());
|
||||
_queues.get(priority).remove(event.getAsyncContext());
|
||||
super.onTimeout(event);
|
||||
}
|
||||
}
|
||||
|
||||
private String createRemotePortId(ServletRequest request)
|
||||
public enum Action
|
||||
{
|
||||
String addr = request.getRemoteAddr();
|
||||
int port = request.getRemotePort();
|
||||
return addr + ":" + port;
|
||||
/**
|
||||
* No action is taken against the Request, it is allowed to be processed normally.
|
||||
*/
|
||||
NO_ACTION,
|
||||
/**
|
||||
* The request and response is aborted, no response is sent.
|
||||
*/
|
||||
ABORT,
|
||||
/**
|
||||
* The request is rejected by sending an error based on {@link DoSFilter#getTooManyCode()}
|
||||
*/
|
||||
REJECT,
|
||||
/**
|
||||
* The request is delayed based on {@link DoSFilter#getDelayMs()}
|
||||
*/
|
||||
DELAY,
|
||||
/**
|
||||
* The request is throttled.
|
||||
*/
|
||||
THROTTLE;
|
||||
|
||||
/**
|
||||
* Obtain the Action based on configured {@link DoSFilter#getDelayMs()}
|
||||
*
|
||||
* @param delayMs the delay in milliseconds.
|
||||
* @return the Action proposed.
|
||||
*/
|
||||
public static Action fromDelay(long delayMs)
|
||||
{
|
||||
if (delayMs < 0)
|
||||
return Action.REJECT;
|
||||
|
||||
if (delayMs == 0)
|
||||
return Action.THROTTLE;
|
||||
|
||||
return Action.DELAY;
|
||||
}
|
||||
}
|
||||
|
||||
public interface OverLimit
|
||||
{
|
||||
RateType getRateType();
|
||||
|
||||
String getRateId();
|
||||
|
||||
Duration getDuration();
|
||||
|
||||
long getCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Listener for actions taken against specific requests.
|
||||
*/
|
||||
public static class Listener
|
||||
{
|
||||
/**
|
||||
* Process the onRequestOverLimit() behavior.
|
||||
*
|
||||
* @param request the request that is over the limit
|
||||
* @param dosFilter the {@link DoSFilter} that this event occurred on
|
||||
* @return the action to actually perform.
|
||||
*/
|
||||
public Action onRequestOverLimit(HttpServletRequest request, OverLimit overlimit, DoSFilter dosFilter)
|
||||
{
|
||||
Action action = Action.fromDelay(dosFilter.getDelayMs());
|
||||
|
||||
switch (action)
|
||||
{
|
||||
case REJECT:
|
||||
LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
|
||||
break;
|
||||
case DELAY:
|
||||
LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
|
||||
break;
|
||||
case THROTTLE:
|
||||
LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal());
|
||||
break;
|
||||
}
|
||||
|
||||
return action;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -174,12 +174,12 @@ public class DoSFilterTest extends AbstractDoSFilterTest
|
|||
{
|
||||
boolean exceeded = false;
|
||||
ServletContext context = new ContextHandler.StaticContext();
|
||||
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 0, 4);
|
||||
RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4);
|
||||
|
||||
for (int i = 0; i < 5; i++)
|
||||
{
|
||||
Thread.sleep(sleep);
|
||||
if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime())))
|
||||
if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime())) != null)
|
||||
exceeded = true;
|
||||
}
|
||||
return exceeded;
|
||||
|
|
Loading…
Reference in New Issue