[277798] Denial of Service Filter

git-svn-id: svn+ssh://dev.eclipse.org/svnroot/rt/org.eclipse.jetty/jetty/trunk@291 7e9141cc-0065-0410-87d8-b60c137991c4
This commit is contained in:
Athena Yao 2009-05-26 06:48:45 +00:00
parent 3e27e3bf9e
commit 60f784d2ac
5 changed files with 1142 additions and 0 deletions

View File

@ -4,6 +4,7 @@ jetty-7.0.0.M3-SNAPSHOT
+ added WebAppContext.setConfigurationDiscovered for servlet 3.0 features
+ 274251 Allow dispatch to welcome files that are servlets (configurable)
+ 277403 Cleanup system property usage.
+ 277798 Denial of Service Filter
+ Portable continuations for jetty6 and servlet3
jetty-7.0.0.M2 18 May 2009

View File

@ -0,0 +1,45 @@
// ========================================================================
// Copyright (c) 2009 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
package org.eclipse.jetty.servlets;
import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.server.HttpConnection;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.log.Log;
/* ------------------------------------------------------------ */
/** Closeable DoS Filter.
* This is an extension to the {@link DoSFilter} that uses Jetty APIs to allow
* connections to be closed cleanly.
*/
public class CloseableDoSFilter extends DoSFilter
{
protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
{
try
{
Request base_request=(request instanceof Request)?(Request)request:HttpConnection.getCurrentConnection().getRequest();
base_request.getConnection().getEndPoint().close();
}
catch(IOException e)
{
Log.warn(e);
}
}
}

View File

@ -0,0 +1,75 @@
// ========================================================================
// Copyright (c) 2009 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
package org.eclipse.jetty.servlets;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.testing.ServletTester;
import org.eclipse.jetty.util.log.Log;
public class CloseableDoSFilterTest extends DoSFilterTest
{
protected void setUp() throws Exception
{
_tester = new ServletTester();
HttpURI uri=new HttpURI(_tester.createSocketConnector(true));
_host=uri.getHost();
_port=uri.getPort();
_tester.setContextPath("/ctx");
_tester.addServlet(TestServlet.class, "/*");
FilterHolder dos=_tester.addFilter(CloseableDoSFilter2.class,"/dos/*",0);
dos.setInitParameter("maxRequestsPerSec","4");
dos.setInitParameter("delayMs","200");
dos.setInitParameter("throttledRequests","1");
dos.setInitParameter("waitMs","10");
dos.setInitParameter("throttleMs","4000");
dos.setInitParameter("remotePort", "false");
dos.setInitParameter("insertHeaders", "true");
FilterHolder quickTimeout = _tester.addFilter(CloseableDoSFilter2.class,"/timeout/*",0);
quickTimeout.setInitParameter("maxRequestsPerSec","4");
quickTimeout.setInitParameter("delayMs","200");
quickTimeout.setInitParameter("throttledRequests","1");
quickTimeout.setInitParameter("waitMs","10");
quickTimeout.setInitParameter("throttleMs","4000");
quickTimeout.setInitParameter("remotePort", "false");
quickTimeout.setInitParameter("insertHeaders", "true");
quickTimeout.setInitParameter("maxRequestMs", _maxRequestMs + "");
_tester.start();
}
public static class CloseableDoSFilter2 extends CloseableDoSFilter
{
public void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
{
try
{
response.getWriter().append("DoSFilter: timeout");
response.flushBuffer();
super.closeConnection(request,response,thread);
}
catch (Exception e)
{
Log.warn(e);
}
}
}
}

View File

@ -0,0 +1,676 @@
// ========================================================================
// Copyright (c) 2009 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
package org.eclipse.jetty.servlets;
import java.io.IOException;
import java.util.HashSet;
import java.util.Queue;
import java.util.StringTokenizer;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import org.eclipse.jetty.continuation.Continuation;
import org.eclipse.jetty.continuation.ContinuationSupport;
import org.eclipse.jetty.util.ArrayQueue;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.thread.Timeout;
/**
* Denial of Service filter
*
* <p>
* This filter is based on the {@link QoSFilter}. it is useful for limiting
* exposure to abuse from request flooding, whether malicious, or as a result of
* a misconfigured client.
* <p>
* The filter keeps track of the number of requests from a connection per
* second. If a limit is exceeded, the request is either rejected, delayed, or
* throttled.
* <p>
* When a request is throttled, it is placed in a priority queue. Priority is
* given first to authenticated users and users with an HttpSession, then
* connections which can be identified by their IP addresses. Connections with
* no way to identify them are given lowest priority.
* <p>
* The {@link #extractUserId(ServletRequest request)} function should be
* implemented, in order to uniquely identify authenticated users.
* <p>
* The following init parameters control the behavior of the filter:
*
* maxRequestsPerSec the maximum number of requests from a connection per
* second. Requests in excess of this are first delayed,
* then throttled.
*
* delayMs is the delay given to all requests over the rate limit,
* before they are considered at all. -1 means just reject request,
* 0 means no delay, otherwise it is the delay.
*
* maxWaitMs how long to blocking wait for the throttle semaphore.
*
* throttledRequests is the number of requests over the rate limit able to be
* considered at once.
*
* throttleMs how long to async wait for semaphore.
*
* maxRequestMs how long to allow this request to run.
*
* maxIdleTrackerMs how long to keep track of request rates for a connection,
* before deciding that the user has gone away, and discarding it
*
* insertHeaders if true , insert the DoSFilter headers into the response. Defaults to true.
*
* trackSessions if true, usage rate is tracked by session if a session exists. Defaults to true.
*
* remotePort if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.
*
* ipWhitelist a comma-separated list of IP addresses that will not be rate limited
*/
public class DoSFilter implements Filter
{
final static String __TRACKER = "DoSFilter.Tracker";
final static String __THROTTLED = "DoSFilter.Throttled";
final static int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
final static int __DEFAULT_DELAY_MS = 100;
final static int __DEFAULT_THROTTLE = 5;
final static int __DEFAULT_WAIT_MS=50;
final static long __DEFAULT_THROTTLE_MS = 30000L;
final static long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM=30000L;
final static long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM=30000L;
final static String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
final static String DELAY_MS_INIT_PARAM = "delayMs";
final static String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
final static String MAX_WAIT_INIT_PARAM="maxWaitMs";
final static String THROTTLE_MS_INIT_PARAM = "throttleMs";
final static String MAX_REQUEST_MS_INIT_PARAM="maxRequestMs";
final static String MAX_IDLE_TRACKER_MS_INIT_PARAM="maxIdleTrackerMs";
final static String INSERT_HEADERS_INIT_PARAM="insertHeaders";
final static String TRACK_SESSIONS_INIT_PARAM="trackSessions";
final static String REMOTE_PORT_INIT_PARAM="remotePort";
final static String IP_WHITELIST_INIT_PARAM="ipWhitelist";
final static int USER_AUTH = 2;
final static int USER_SESSION = 2;
final static int USER_IP = 1;
final static int USER_UNKNOWN = 0;
ServletContext _context;
protected long _delayMs;
protected long _throttleMs;
protected long _waitMs;
protected long _maxRequestMs;
protected long _maxIdleTrackerMs;
protected boolean _insertHeaders;
protected boolean _trackSessions;
protected boolean _remotePort;
protected Semaphore _passes;
protected Queue<Continuation>[] _queue;
protected int _maxRequestsPerSec;
protected final ConcurrentHashMap<String, RateTracker> _rateTrackers=new ConcurrentHashMap<String, RateTracker>();
private HashSet<String> _whitelist;
private final Timeout _requestTimeoutQ = new Timeout();
private final Timeout _trackerTimeoutQ = new Timeout();
private Thread _timerThread;
private volatile boolean _running;
public void init(FilterConfig filterConfig)
{
_context = filterConfig.getServletContext();
_queue = new Queue[getMaxPriority() + 1];
for (int p = 0; p < _queue.length; p++)
_queue[p] = new ArrayQueue<Continuation>();
int baseRateLimit = __DEFAULT_MAX_REQUESTS_PER_SEC;
if (filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM) != null)
baseRateLimit = Integer.parseInt(filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM));
_maxRequestsPerSec = baseRateLimit;
long delay = __DEFAULT_DELAY_MS;
if (filterConfig.getInitParameter(DELAY_MS_INIT_PARAM) != null)
delay = Integer.parseInt(filterConfig.getInitParameter(DELAY_MS_INIT_PARAM));
_delayMs = delay;
int passes = __DEFAULT_THROTTLE;
if (filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM) != null)
passes = Integer.parseInt(filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM));
_passes = new Semaphore(passes,true);
long wait = __DEFAULT_WAIT_MS;
if (filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM) != null)
wait = Integer.parseInt(filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM));
_waitMs = wait;
long suspend = __DEFAULT_THROTTLE_MS;
if (filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM) != null)
suspend = Integer.parseInt(filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM));
_throttleMs = suspend;
long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
if (filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM) != null )
maxRequestMs = Long.parseLong(filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM));
_maxRequestMs = maxRequestMs;
long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
if (filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM) != null )
maxIdleTrackerMs = Long.parseLong(filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM));
_maxIdleTrackerMs = maxIdleTrackerMs;
String whitelistString = "";
if (filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM) !=null )
whitelistString = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
// empty
if (whitelistString.length() == 0 )
_whitelist = new HashSet<String>();
else
{
StringTokenizer tokenizer = new StringTokenizer(whitelistString, ",");
_whitelist = new HashSet<String>(tokenizer.countTokens());
while (tokenizer.hasMoreTokens())
_whitelist.add(tokenizer.nextToken().trim());
Log.info("Whitelisted IP addresses: {}", _whitelist.toString());
}
String tmp = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
_insertHeaders = tmp==null || Boolean.parseBoolean(tmp);
tmp = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
_trackSessions = tmp==null || Boolean.parseBoolean(tmp);
tmp = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
_remotePort = tmp!=null&& Boolean.parseBoolean(tmp);
_requestTimeoutQ.setNow();
_requestTimeoutQ.setDuration(_maxRequestMs);
_trackerTimeoutQ.setNow();
_trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
_running=true;
_timerThread = (new Thread()
{
public void run()
{
try
{
while (_running)
{
synchronized (_requestTimeoutQ)
{
_requestTimeoutQ.setNow();
_requestTimeoutQ.tick();
_trackerTimeoutQ.setNow(_requestTimeoutQ.getNow());
_trackerTimeoutQ.tick();
}
try
{
Thread.sleep(100);
}
catch (InterruptedException e)
{
Log.ignore(e);
}
}
}
finally
{
Log.info("DoSFilter timer exited");
}
}
});
_timerThread.start();
}
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterchain) throws IOException, ServletException
{
final HttpServletRequest srequest = (HttpServletRequest)request;
final HttpServletResponse sresponse = (HttpServletResponse)response;
final long now=_requestTimeoutQ.getNow();
// Look for the rate tracker for this request
RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
if (tracker==null)
{
// This is the first time we have seen this request.
// get a rate tracker associated with this request, and record one hit
tracker = getRateTracker(request);
// Calculate the rate and check it is over the allowed limit
final boolean overRateLimit = tracker.isRateExceeded(now);
// pass it through if we are not currently over the rate limit
if (!overRateLimit)
{
doFilterChain(filterchain,srequest,sresponse);
return;
}
// We are over the limit.
Log.warn("DOS ALERT: ip="+srequest.getRemoteAddr()+",session="+srequest.getRequestedSessionId()+",user="+srequest.getUserPrincipal());
// So either reject it, delay it or throttle it
switch((int)_delayMs)
{
case -1:
{
// Reject this request
if (_insertHeaders)
((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
return;
}
case 0:
{
// fall through to throttle code
request.setAttribute(__TRACKER,tracker);
break;
}
default:
{
// insert a delay before throttling the request
if (_insertHeaders)
((HttpServletResponse)response).addHeader("DoSFilter","delayed");
Continuation continuation = ContinuationSupport.getContinuation(request,response);
request.setAttribute(__TRACKER,tracker);
if (_delayMs > 0)
continuation.setTimeout(_delayMs);
continuation.suspend();
return;
}
}
}
// Throttle the request
boolean accepted = false;
try
{
// check if we can afford to accept another request at this time
accepted = _passes.tryAcquire(_waitMs,TimeUnit.MILLISECONDS);
if (!accepted)
{
// we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
final Continuation continuation = ContinuationSupport.getContinuation(request,response);
Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
if (throttled!=Boolean.TRUE && _throttleMs>0)
{
int priority = getPriority(request,tracker);
request.setAttribute(__THROTTLED,Boolean.TRUE);
if (_insertHeaders)
((HttpServletResponse)response).addHeader("DoSFilter","throttled");
if (_throttleMs > 0)
continuation.setTimeout(_throttleMs);
continuation.suspend();
_queue[priority].add(continuation);
return;
}
// else were we resumed?
else if (request.getAttribute("javax.servlet.resumed")==Boolean.TRUE)
{
// we were resumed and somebody stole our pass, so we wait for the next one.
_passes.acquire();
accepted = true;
}
}
// if we were accepted (either immediately or after throttle)
if (accepted)
// call the chain
doFilterChain(filterchain,srequest,sresponse);
else
{
// fail the request
if (_insertHeaders)
((HttpServletResponse)response).addHeader("DoSFilter","unavailable");
((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
}
}
catch (InterruptedException e)
{
_context.log("DoS",e);
((HttpServletResponse)response).sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
}
finally
{
if (accepted)
{
// wake up the next highest priority request.
synchronized (_queue)
{
for (int p = _queue.length; p-- > 0;)
{
Continuation continuation = _queue[p].poll();
if (continuation != null)
{
continuation.resume();
break;
}
}
}
_passes.release();
}
}
}
/**
* @param chain
* @param request
* @param response
* @throws IOException
* @throws ServletException
*/
protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)
throws IOException, ServletException
{
final Thread thread=Thread.currentThread();
final Timeout.Task requestTimeout = new Timeout.Task()
{
public void expired()
{
closeConnection(request, response, thread);
}
};
try
{
synchronized (_requestTimeoutQ)
{
_requestTimeoutQ.schedule(requestTimeout);
}
chain.doFilter(request,response);
}
finally
{
synchronized (_requestTimeoutQ)
{
requestTimeout.cancel();
}
}
}
/**
* Takes drastic measures to return this response and stop this thread.
* Due to the way the connection is interrupted, may return mixed up headers.
* @param request current request
* @param response current response, which must be stopped
* @param thread the handling thread
*/
protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
{
// take drastic measures to return this response and stop this thread.
if( !response.isCommitted() )
{
response.setHeader("Connection", "close");
}
try
{
try
{
response.getWriter().close();
}
catch (IllegalStateException e)
{
response.getOutputStream().close();
}
}
catch (IOException e)
{
Log.warn(e);
}
// interrupt the handling thread
thread.interrupt();
}
/**
* Get priority for this request, based on user type
*
* @param request
* @param tracker
* @return priority
*/
protected int getPriority(ServletRequest request, RateTracker tracker)
{
if (extractUserId(request)!=null)
return USER_AUTH;
if (tracker!=null)
return tracker.getType();
return USER_UNKNOWN;
}
/**
* @return the maximum priority that we can assign to a request
*/
protected int getMaxPriority()
{
return USER_AUTH;
}
/**
* Return a request rate tracker associated with this connection; keeps
* track of this connection's request rate. If this is not the first request
* from this connection, return the existing object with the stored stats.
* If it is the first request, then create a new request tracker.
*
* Assumes that each connection has an identifying characteristic, and goes
* through them in order, taking the first that matches: user id (logged
* in), session id, client IP address. Unidentifiable connections are lumped
* into one.
*
* When a session expires, its rate tracker is automatically deleted.
*
* @param request
* @return the request rate tracker for the current connection
*/
public RateTracker getRateTracker(ServletRequest request)
{
HttpServletRequest srequest = (HttpServletRequest)request;
String loadId;
final int type;
loadId = extractUserId(request);
HttpSession session=srequest.getSession(false);
if (_trackSessions && session!=null && !session.isNew())
{
loadId=session.getId();
type = USER_SESSION;
}
else
{
loadId = _remotePort?(request.getRemoteAddr()+request.getRemotePort()):request.getRemoteAddr();
type = USER_IP;
}
RateTracker tracker=_rateTrackers.get(loadId);
if (tracker==null)
{
RateTracker t;
if (_whitelist.contains(request.getRemoteAddr()))
{
t = new FixedRateTracker(loadId,type,_maxRequestsPerSec);
}
else
{
t = new RateTracker(loadId,type,_maxRequestsPerSec);
}
tracker=_rateTrackers.putIfAbsent(loadId,t);
if (tracker==null)
tracker=t;
if (type == USER_IP)
{
// USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
synchronized (_trackerTimeoutQ)
{
_trackerTimeoutQ.schedule(tracker);
}
}
else if (session!=null)
// USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
session.setAttribute(__TRACKER,tracker);
}
return tracker;
}
public void destroy()
{
_running=false;
_timerThread.interrupt();
synchronized (_requestTimeoutQ)
{
_requestTimeoutQ.cancelAll();
_trackerTimeoutQ.cancelAll();
}
}
/**
* Returns the user id, used to track this connection.
* This SHOULD be overridden by subclasses.
*
* @param request
* @return a unique user id, if logged in; otherwise null.
*/
protected String extractUserId(ServletRequest request)
{
return null;
}
/**
* A RateTracker is associated with a connection, and stores request rate
* data.
*/
class RateTracker extends Timeout.Task implements HttpSessionBindingListener
{
protected final String _id;
protected final int _type;
protected final long[] _timestamps;
protected int _next;
public RateTracker(String id, int type,int maxRequestsPerSecond)
{
_id = id;
_type = type;
_timestamps=new long[maxRequestsPerSecond];
_next=0;
}
/**
* @return the current calculated request rate over the last second
*/
public boolean isRateExceeded(long now)
{
final long last;
synchronized (this)
{
last=_timestamps[_next];
_timestamps[_next]=now;
_next= (_next+1)%_timestamps.length;
}
boolean exceeded=last!=0 && (now-last)<1000L;
return exceeded;
}
public String getId()
{
return _id;
}
public int getType()
{
return _type;
}
public void valueBound(HttpSessionBindingEvent event)
{
}
public void valueUnbound(HttpSessionBindingEvent event)
{
_rateTrackers.remove(_id);
}
public void expired()
{
long now = _trackerTimeoutQ.getNow();
int latestIndex = _next == 0 ? 3 : (_next - 1 ) % _timestamps.length;
long last=_timestamps[latestIndex];
boolean hasRecentRequest = last != 0 && (now-last)<1000L;
if (hasRecentRequest)
reschedule();
else
_rateTrackers.remove(_id);
}
}
class FixedRateTracker extends RateTracker
{
public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
{
super(id,type,numRecentRequestsTracked);
}
public boolean isRateExceeded(long now)
{
// rate limit is never exceeded, but we keep track of the request timestamps
// so that we know whether there was recent activity on this tracker
// and whether it should be expired
synchronized (this)
{
_timestamps[_next]=now;
_next= (_next+1)%_timestamps.length;
}
return false;
}
}
}

View File

@ -0,0 +1,345 @@
// ========================================================================
// Copyright 2009 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//========================================================================
package org.eclipse.jetty.servlets;
import java.io.IOException;
import java.net.Socket;
import javax.servlet.Servlet;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import junit.framework.TestCase;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.testing.ServletTester;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.log.Log;
public class DoSFilterTest extends TestCase
{
protected ServletTester _tester;
protected String _host;
protected int _port;
protected int _maxRequestMs = 200;
protected void setUp() throws Exception
{
_tester = new ServletTester();
HttpURI uri=new HttpURI(_tester.createChannelConnector(true));
_host=uri.getHost();
_port=uri.getPort();
_tester.setContextPath("/ctx");
_tester.addServlet(TestServlet.class, "/*");
FilterHolder dos=_tester.addFilter(DoSFilter2.class,"/dos/*",0);
dos.setInitParameter("maxRequestsPerSec","4");
dos.setInitParameter("delayMs","200");
dos.setInitParameter("throttledRequests","1");
dos.setInitParameter("waitMs","10");
dos.setInitParameter("throttleMs","4000");
dos.setInitParameter("remotePort", "false");
dos.setInitParameter("insertHeaders", "true");
FilterHolder quickTimeout = _tester.addFilter(DoSFilter2.class,"/timeout/*",0);
quickTimeout.setInitParameter("maxRequestsPerSec","4");
quickTimeout.setInitParameter("delayMs","200");
quickTimeout.setInitParameter("throttledRequests","1");
quickTimeout.setInitParameter("waitMs","10");
quickTimeout.setInitParameter("throttleMs","4000");
quickTimeout.setInitParameter("remotePort", "false");
quickTimeout.setInitParameter("insertHeaders", "true");
quickTimeout.setInitParameter("maxRequestMs", _maxRequestMs + "");
_tester.start();
}
protected void tearDown() throws Exception
{
_tester.stop();
}
private String doRequests(String requests, int loops, long pause0,long pause1,String request)
throws Exception
{
Socket socket = new Socket(_host,_port);
socket.setSoTimeout(30000);
for (int i=loops;i-->0;)
{
socket.getOutputStream().write(requests.getBytes("UTF-8"));
socket.getOutputStream().flush();
if (i>0 && pause0>0)
Thread.sleep(pause0);
}
if (pause1>0)
Thread.sleep(pause1);
socket.getOutputStream().write(request.getBytes("UTF-8"));
socket.getOutputStream().flush();
String response = "";
if (requests.contains("/unresponsive"))
{
// don't read in anything, forcing the request to time out
Thread.sleep(_maxRequestMs * 2);
response = IO.toString(socket.getInputStream(),"UTF-8");
}
else
{
response = IO.toString(socket.getInputStream(),"UTF-8");
}
socket.close();
return response;
}
private int count(String responses,String substring)
{
int count=0;
int i=responses.indexOf(substring);
while (i>=0)
{
count++;
i=responses.indexOf(substring,i+substring.length());
}
return count;
}
public void testEvenLowRateIP()
throws Exception
{
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request,11,300,300,last);
assertEquals(12,count(responses,"HTTP/1.1 200 OK"));
assertEquals(0,count(responses,"DoSFilter:"));
}
public void testBurstLowRateIP()
throws Exception
{
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request+request+request+request,2,1100,1100,last);
assertEquals(9,count(responses,"HTTP/1.1 200 OK"));
assertEquals(0,count(responses,"DoSFilter:"));
}
public void testDelayedIP()
throws Exception
{
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request+request+request+request+request,2,1100,1100,last);
assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
assertEquals(2,count(responses,"DoSFilter: delayed"));
}
public void testThrottledIP()
throws Exception
{
Thread other = new Thread()
{
public void run()
{
try
{
// Cause a delay, then sleep while holding pass
String request="GET /ctx/dos/sleeper HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/sleeper?sleep=2000 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request+request+request+request,1,0,0,last);
}
catch(Exception e)
{
e.printStackTrace();
}
}
};
other.start();
Thread.sleep(1500);
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request+request+request+request,1,0,0,last);
//System.out.println("responses are " + responses);
assertEquals(5,count(responses,"HTTP/1.1 200 OK"));
assertEquals(1,count(responses,"DoSFilter: delayed"));
assertEquals(1,count(responses,"DoSFilter: throttled"));
assertEquals(0,count(responses,"DoSFilter: unavailable"));
other.join();
}
public void testUnavailableIP()
throws Exception
{
Thread other = new Thread()
{
public void run()
{
try
{
// Cause a delay, then sleep while holding pass
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/test?sleep=5000 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request+request+request+request,1,0,0,last);
}
catch(Exception e)
{
e.printStackTrace();
}
}
};
other.start();
Thread.sleep(500);
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests(request+request+request+request,1,0,0,last);
assertEquals(4,count(responses,"HTTP/1.1 200 OK"));
assertEquals(1,count(responses,"HTTP/1.1 503"));
assertEquals(1,count(responses,"DoSFilter: delayed"));
assertEquals(1,count(responses,"DoSFilter: throttled"));
assertEquals(1,count(responses,"DoSFilter: unavailable"));
other.join();
}
public void testSessionTracking()
throws Exception
{
// get a session, first
String requestSession="GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String response=doRequests("",1,0,0,requestSession);
String sessionId=response.substring(response.indexOf("Set-Cookie: ")+12, response.indexOf(";"));
// all other requests use this session
String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n";
String responses = doRequests(request+request+request+request+request,2,1100,1100,last);
assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
assertEquals(2,count(responses,"DoSFilter: delayed"));
}
public void testMultipleSessionTracking()
throws Exception
{
// get some session ids, first
String requestSession="GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n";
String closeRequest="GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String response=doRequests(requestSession+requestSession,1,0,0,closeRequest);
String[] sessions = response.split("\r\n\r\n");
String sessionId1=sessions[0].substring(sessions[0].indexOf("Set-Cookie: ")+12, sessions[0].indexOf(";"));
String sessionId2=sessions[1].substring(sessions[1].indexOf("Set-Cookie: ")+12, sessions[1].indexOf(";"));
// alternate between sessions
String request1="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n";
String request2="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n";
String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n";
String responses = doRequests(request1+request2+request1+request2+request1,2,1100,1100,last);
assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
assertEquals(0,count(responses,"DoSFilter: delayed"));
// alternate between sessions
responses = doRequests(request1+request2+request1+request2+request1,2,550,550,last);
assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
int delayedRequests = count(responses,"DoSFilter: delayed");
assertTrue(delayedRequests >= 2 && delayedRequests <= 3);
}
public void testUnresponsiveClient()
throws Exception
{
int numRequests = 1000;
String last="GET /ctx/timeout/unresponsive?lines="+numRequests+" HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
String responses = doRequests("",0,0,0,last);
// was expired, and stopped before reaching the end of the requests
int responseLines = count(responses, "Line:");
assertTrue(responses.contains("DoSFilter: timeout"));
assertTrue(responseLines > 0 && responseLines < numRequests);
}
public static class TestServlet extends HttpServlet implements Servlet
{
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
if (request.getParameter("session")!=null)
request.getSession(true);
if (request.getParameter("sleep")!=null)
{
try
{
Thread.sleep(Long.parseLong(request.getParameter("sleep")));
}
catch(InterruptedException e)
{
}
}
if (request.getParameter("lines")!=null)
{
int count = Integer.parseInt(request.getParameter("lines"));
for(int i = 0; i < count; ++i)
{
response.getWriter().append("Line: " + i+"\n");
response.flushBuffer();
try
{
Thread.sleep(10);
}
catch(InterruptedException e)
{
}
}
}
response.setContentType("text/plain");
}
}
public static class DoSFilter2 extends DoSFilter
{
public void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
{
try {
response.getWriter().append("DoSFilter: timeout");
super.closeConnection(request,response,thread);
}
catch (Exception e)
{
Log.warn(e);
}
}
}
}