SEC-29: Save POST parameters on AuthenticationEntryPoint redirect.

This commit is contained in:
Ben Alex 2006-04-28 05:05:35 +00:00
parent 2d6813d354
commit d125569bd6
16 changed files with 1475 additions and 116 deletions

View File

@ -1,4 +1,4 @@
/* Copyright 2004, 2005 Acegi Technology Pty Limited /* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,6 +15,8 @@
package org.acegisecurity.intercept.web; package org.acegisecurity.intercept.web;
import org.acegisecurity.util.UrlUtils;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
@ -88,10 +90,7 @@ public class FilterInvocation {
* @return the full URL of this request * @return the full URL of this request
*/ */
public String getFullRequestUrl() { public String getFullRequestUrl() {
return getHttpRequest().getScheme() + "://" return UrlUtils.getFullRequestUrl(this);
+ getHttpRequest().getServerName() + ":"
+ getHttpRequest().getServerPort() + getHttpRequest().getContextPath()
+ getRequestUrl();
} }
public HttpServletRequest getHttpRequest() { public HttpServletRequest getHttpRequest() {
@ -106,19 +105,13 @@ public class FilterInvocation {
return request; return request;
} }
/**
* Obtains the web application-specific fragment of the URL.
*
* @return the URL, excluding any server name, context path or servlet path
*/
public String getRequestUrl() { public String getRequestUrl() {
String pathInfo = getHttpRequest().getPathInfo(); return UrlUtils.getRequestUrl(this);
String queryString = getHttpRequest().getQueryString();
String uri = getHttpRequest().getServletPath();
if (uri == null) {
uri = getHttpRequest().getRequestURI();
uri = uri.substring(getHttpRequest().getContextPath().length());
}
return uri + ((pathInfo == null) ? "" : pathInfo)
+ ((queryString == null) ? "" : ("?" + queryString));
} }
public ServletResponse getResponse() { public ServletResponse getResponse() {

View File

@ -1,4 +1,4 @@
/* Copyright 2004 Acegi Technology Pty Limited /* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@ package org.acegisecurity.securechannel;
import org.acegisecurity.ConfigAttribute; import org.acegisecurity.ConfigAttribute;
import org.acegisecurity.ConfigAttributeDefinition; import org.acegisecurity.ConfigAttributeDefinition;
import org.acegisecurity.intercept.web.FilterInvocation; import org.acegisecurity.intercept.web.FilterInvocation;
import org.acegisecurity.intercept.web.FilterInvocationDefinitionSource; import org.acegisecurity.intercept.web.FilterInvocationDefinitionSource;
@ -24,6 +25,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.io.IOException; import java.io.IOException;
@ -78,34 +80,19 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
//~ Methods ================================================================ //~ Methods ================================================================
public void setChannelDecisionManager(
ChannelDecisionManager channelDecisionManager) {
this.channelDecisionManager = channelDecisionManager;
}
public ChannelDecisionManager getChannelDecisionManager() {
return channelDecisionManager;
}
public void setFilterInvocationDefinitionSource(
FilterInvocationDefinitionSource filterInvocationDefinitionSource) {
this.filterInvocationDefinitionSource = filterInvocationDefinitionSource;
}
public FilterInvocationDefinitionSource getFilterInvocationDefinitionSource() {
return filterInvocationDefinitionSource;
}
public void afterPropertiesSet() throws Exception { public void afterPropertiesSet() throws Exception {
Assert.notNull(filterInvocationDefinitionSource, "filterInvocationDefinitionSource must be specified"); Assert.notNull(filterInvocationDefinitionSource,
Assert.notNull(channelDecisionManager, "channelDecisionManager must be specified"); "filterInvocationDefinitionSource must be specified");
Assert.notNull(channelDecisionManager,
"channelDecisionManager must be specified");
Iterator iter = this.filterInvocationDefinitionSource Iterator iter = this.filterInvocationDefinitionSource
.getConfigAttributeDefinitions(); .getConfigAttributeDefinitions();
if (iter == null) { if (iter == null) {
if (logger.isWarnEnabled()) { if (logger.isWarnEnabled()) {
logger.warn("Could not validate configuration attributes as the FilterInvocationDefinitionSource did not return a ConfigAttributeDefinition Iterator"); logger.warn(
"Could not validate configuration attributes as the FilterInvocationDefinitionSource did not return a ConfigAttributeDefinition Iterator");
} }
return; return;
@ -115,7 +102,7 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
while (iter.hasNext()) { while (iter.hasNext()) {
ConfigAttributeDefinition def = (ConfigAttributeDefinition) iter ConfigAttributeDefinition def = (ConfigAttributeDefinition) iter
.next(); .next();
Iterator attributes = def.getConfigAttributes(); Iterator attributes = def.getConfigAttributes();
while (attributes.hasNext()) { while (attributes.hasNext()) {
@ -132,7 +119,8 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
logger.info("Validated configuration attributes"); logger.info("Validated configuration attributes");
} }
} else { } else {
throw new IllegalArgumentException("Unsupported configuration attributes: " + set.toString()); throw new IllegalArgumentException(
"Unsupported configuration attributes: " + set.toString());
} }
} }
@ -154,7 +142,7 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
if (attr != null) { if (attr != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Request: " + fi.getFullRequestUrl() logger.debug("Request: " + fi.toString()
+ "; ConfigAttributes: " + attr.toString()); + "; ConfigAttributes: " + attr.toString());
} }
@ -168,5 +156,23 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
chain.doFilter(request, response); chain.doFilter(request, response);
} }
public ChannelDecisionManager getChannelDecisionManager() {
return channelDecisionManager;
}
public FilterInvocationDefinitionSource getFilterInvocationDefinitionSource() {
return filterInvocationDefinitionSource;
}
public void init(FilterConfig filterConfig) throws ServletException {} public void init(FilterConfig filterConfig) throws ServletException {}
public void setChannelDecisionManager(
ChannelDecisionManager channelDecisionManager) {
this.channelDecisionManager = channelDecisionManager;
}
public void setFilterInvocationDefinitionSource(
FilterInvocationDefinitionSource filterInvocationDefinitionSource) {
this.filterInvocationDefinitionSource = filterInvocationDefinitionSource;
}
} }

View File

@ -26,6 +26,7 @@ import org.acegisecurity.event.authentication.InteractiveAuthenticationSuccessEv
import org.acegisecurity.ui.rememberme.NullRememberMeServices; import org.acegisecurity.ui.rememberme.NullRememberMeServices;
import org.acegisecurity.ui.rememberme.RememberMeServices; import org.acegisecurity.ui.rememberme.RememberMeServices;
import org.acegisecurity.ui.savedrequest.SavedRequest;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -78,10 +79,12 @@ import javax.servlet.http.HttpServletResponse;
* <li> * <li>
* <code>defaultTargetUrl</code> indicates the URL that should be used for * <code>defaultTargetUrl</code> indicates the URL that should be used for
* redirection if the <code>HttpSession</code> attribute named {@link * redirection if the <code>HttpSession</code> attribute named {@link
* #ACEGI_SECURITY_TARGET_URL_KEY} does not indicate the target URL once * #ACEGI_SAVED_REQUEST_KEY} does not indicate the target URL once
* authentication is completed successfully. eg: <code>/</code>. This will be * authentication is completed successfully. eg: <code>/</code>. The
* treated as relative to the web-app's context path, and should include the * <code>defaultTargetUrl</code> will be treated as relative to the web-app's
* leading <code>/</code>. * context path, and should include the leading <code>/</code>. Alternatively,
* inclusion of a scheme name (eg http:// or https://) as the prefix will
* denote a fully-qualified URL and this is also supported.
* </li> * </li>
* <li> * <li>
* <code>authenticationFailureUrl</code> indicates the URL that should be used * <code>authenticationFailureUrl</code> indicates the URL that should be used
@ -95,8 +98,8 @@ import javax.servlet.http.HttpServletResponse;
* <li> * <li>
* <code>alwaysUseDefaultTargetUrl</code> causes successful authentication to * <code>alwaysUseDefaultTargetUrl</code> causes successful authentication to
* always redirect to the <code>defaultTargetUrl</code>, even if the * always redirect to the <code>defaultTargetUrl</code>, even if the
* <code>HttpSession</code> attribute named {@link * <code>HttpSession</code> attribute named {@link #ACEGI_SAVED_REQUEST_KEY}
* #ACEGI_SECURITY_TARGET_URL_KEY} defines the intended target URL. * defines the intended target URL.
* </li> * </li>
* </ul> * </ul>
* *
@ -132,12 +135,15 @@ import javax.servlet.http.HttpServletResponse;
* recorded via an <code>AuthenticationManager</code>-specific application * recorded via an <code>AuthenticationManager</code>-specific application
* event. * event.
* </p> * </p>
*
* @author Ben Alex
* @version $Id$
*/ */
public abstract class AbstractProcessingFilter implements Filter, public abstract class AbstractProcessingFilter implements Filter,
InitializingBean, ApplicationEventPublisherAware, MessageSourceAware { InitializingBean, ApplicationEventPublisherAware, MessageSourceAware {
//~ Static fields/initializers ============================================= //~ Static fields/initializers =============================================
public static final String ACEGI_SECURITY_TARGET_URL_KEY = "ACEGI_SECURITY_TARGET_URL"; public static final String ACEGI_SAVED_REQUEST_KEY = "ACEGI_SAVED_REQUEST_KEY";
public static final String ACEGI_SECURITY_LAST_EXCEPTION_KEY = "ACEGI_SECURITY_LAST_EXCEPTION"; public static final String ACEGI_SECURITY_LAST_EXCEPTION_KEY = "ACEGI_SECURITY_LAST_EXCEPTION";
//~ Instance fields ======================================================== //~ Instance fields ========================================================
@ -303,6 +309,13 @@ public abstract class AbstractProcessingFilter implements Filter,
return continueChainBeforeSuccessfulAuthentication; return continueChainBeforeSuccessfulAuthentication;
} }
public static String obtainFullRequestUrl(HttpServletRequest request) {
SavedRequest savedRequest = (SavedRequest) request.getSession()
.getAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY);
return (savedRequest == null) ? null : savedRequest.getFullRequestUrl();
}
protected void onPreAuthentication(HttpServletRequest request, protected void onPreAuthentication(HttpServletRequest request,
HttpServletResponse response) HttpServletResponse response)
throws AuthenticationException, IOException {} throws AuthenticationException, IOException {}
@ -428,9 +441,7 @@ public abstract class AbstractProcessingFilter implements Filter,
+ authResult + "'"); + authResult + "'");
} }
String targetUrl = (String) request.getSession() String targetUrl = obtainFullRequestUrl(request);
.getAttribute(ACEGI_SECURITY_TARGET_URL_KEY);
request.getSession().removeAttribute(ACEGI_SECURITY_TARGET_URL_KEY);
if (alwaysUseDefaultTargetUrl == true) { if (alwaysUseDefaultTargetUrl == true) {
targetUrl = null; targetUrl = null;

View File

@ -24,7 +24,7 @@ import org.acegisecurity.InsufficientAuthenticationException;
import org.acegisecurity.context.SecurityContextHolder; import org.acegisecurity.context.SecurityContextHolder;
import org.acegisecurity.intercept.web.FilterInvocation; import org.acegisecurity.ui.savedrequest.SavedRequest;
import org.acegisecurity.util.PortResolver; import org.acegisecurity.util.PortResolver;
import org.acegisecurity.util.PortResolverImpl; import org.acegisecurity.util.PortResolverImpl;
@ -250,34 +250,20 @@ public class ExceptionTranslationFilter implements Filter, InitializingBean {
AuthenticationException reason) throws ServletException, IOException { AuthenticationException reason) throws ServletException, IOException {
HttpServletRequest httpRequest = (HttpServletRequest) request; HttpServletRequest httpRequest = (HttpServletRequest) request;
int port = portResolver.getServerPort(httpRequest); SavedRequest savedRequest = new SavedRequest(httpRequest, portResolver);
boolean includePort = true;
if ("http".equals(httpRequest.getScheme().toLowerCase())
&& (port == 80)) {
includePort = false;
}
if ("https".equals(httpRequest.getScheme().toLowerCase())
&& (port == 443)) {
includePort = false;
}
String targetUrl = httpRequest.getScheme() + "://"
+ httpRequest.getServerName() + ((includePort) ? (":" + port) : "")
+ httpRequest.getContextPath()
+ new FilterInvocation(request, response, chain).getRequestUrl();
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug( logger.debug(
"Authentication entry point being called; target URL added to Session: " "Authentication entry point being called; SavedRequest added to Session: "
+ targetUrl); + savedRequest);
} }
if (createSessionAllowed) { if (createSessionAllowed) {
// Store the HTTP request itself. Used by AbstractProcessingFilter
// for redirection after successful authentication (SEC-29)
httpRequest.getSession() httpRequest.getSession()
.setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY, .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY,
targetUrl); savedRequest);
} }
// SEC-112: Clear the SecurityContextHolder's Authentication, as the // SEC-112: Clear the SecurityContextHolder's Authentication, as the

View File

@ -0,0 +1,152 @@
/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.acegisecurity.ui.savedrequest;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
/**
* <p>
* Adapter that wraps an <code>Enumeration</code> around a Java 2 collection
* <code>Iterator</code>.
* </p>
*
* <p>
* Constructors are provided to easily create such wrappers.
* </p>
*
* <p>
* This class is based on code in Apache Tomcat.
* </p>
*
* @author Craig McClanahan
* @author Andrey Grebnev
* @version $Id$
*/
public class Enumerator implements Enumeration {
//~ Instance fields ========================================================
/**
* The <code>Iterator</code> over which the <code>Enumeration</code>
* represented by this class actually operates.
*/
private Iterator iterator = null;
//~ Constructors ===========================================================
/**
* Return an Enumeration over the values of the specified Collection.
*
* @param collection Collection whose values should be enumerated
*/
public Enumerator(Collection collection) {
this(collection.iterator());
}
/**
* Return an Enumeration over the values of the specified Collection.
*
* @param collection Collection whose values should be enumerated
* @param clone true to clone iterator
*/
public Enumerator(Collection collection, boolean clone) {
this(collection.iterator(), clone);
}
/**
* Return an Enumeration over the values returned by the specified
* Iterator.
*
* @param iterator Iterator to be wrapped
*/
public Enumerator(Iterator iterator) {
super();
this.iterator = iterator;
}
/**
* Return an Enumeration over the values returned by the specified
* Iterator.
*
* @param iterator Iterator to be wrapped
* @param clone true to clone iterator
*/
public Enumerator(Iterator iterator, boolean clone) {
super();
if (!clone) {
this.iterator = iterator;
} else {
List list = new ArrayList();
while (iterator.hasNext()) {
list.add(iterator.next());
}
this.iterator = list.iterator();
}
}
/**
* Return an Enumeration over the values of the specified Map.
*
* @param map Map whose values should be enumerated
*/
public Enumerator(Map map) {
this(map.values().iterator());
}
/**
* Return an Enumeration over the values of the specified Map.
*
* @param map Map whose values should be enumerated
* @param clone true to clone iterator
*/
public Enumerator(Map map, boolean clone) {
this(map.values().iterator(), clone);
}
//~ Methods ================================================================
/**
* Tests if this enumeration contains more elements.
*
* @return <code>true</code> if and only if this enumeration object
* contains at least one more element to provide,
* <code>false</code> otherwise
*/
public boolean hasMoreElements() {
return (iterator.hasNext());
}
/**
* Returns the next element of this enumeration if this enumeration has at
* least one more element to provide.
*
* @return the next element of this enumeration
*
* @exception NoSuchElementException if no more elements exist
*/
public Object nextElement() throws NoSuchElementException {
return (iterator.next());
}
}

View File

@ -0,0 +1,234 @@
/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.acegisecurity.ui.savedrequest;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.Locale;
import java.util.TimeZone;
/**
* <p>
* Utility class to generate HTTP dates.
* </p>
*
* <p>
* This class is based on code in Apache Tomcat.
* </p>
*
* @author Remy Maucherat
* @author Andrey Grebnev
* @version $Id$
*/
public class FastHttpDateFormat {
//~ Static fields/initializers =============================================
/** HTTP date format. */
protected static final SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz",
Locale.US);
/**
* The set of SimpleDateFormat formats to use in
* <code>getDateHeader()</code>.
*/
protected static final SimpleDateFormat[] formats = {new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz",
Locale.US), new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz",
Locale.US), new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy",
Locale.US)};
/** GMT timezone - all HTTP dates are on GMT */
protected final static TimeZone gmtZone = TimeZone.getTimeZone("GMT");
static {
format.setTimeZone(gmtZone);
formats[0].setTimeZone(gmtZone);
formats[1].setTimeZone(gmtZone);
formats[2].setTimeZone(gmtZone);
}
/** Instant on which the currentDate object was generated. */
protected static long currentDateGenerated = 0L;
/** Current formatted date. */
protected static String currentDate = null;
/** Formatter cache. */
protected static final HashMap formatCache = new HashMap();
/** Parser cache. */
protected static final HashMap parseCache = new HashMap();
//~ Methods ================================================================
/**
* Formats a specified date to HTTP format. If local format is not
* <code>null</code>, it's used instead.
*
* @param value Date value to format
* @param threadLocalformat The format to use (or <code>null</code> -- then
* HTTP format will be used)
*
* @return Formatted date
*/
public static final String formatDate(long value,
DateFormat threadLocalformat) {
String cachedDate = null;
Long longValue = new Long(value);
try {
cachedDate = (String) formatCache.get(longValue);
} catch (Exception e) {}
if (cachedDate != null) {
return cachedDate;
}
String newDate = null;
Date dateValue = new Date(value);
if (threadLocalformat != null) {
newDate = threadLocalformat.format(dateValue);
synchronized (formatCache) {
updateCache(formatCache, longValue, newDate);
}
} else {
synchronized (formatCache) {
newDate = format.format(dateValue);
updateCache(formatCache, longValue, newDate);
}
}
return newDate;
}
/**
* Gets the current date in HTTP format.
*
* @return Current date in HTTP format
*/
public static final String getCurrentDate() {
long now = System.currentTimeMillis();
if ((now - currentDateGenerated) > 1000) {
synchronized (format) {
if ((now - currentDateGenerated) > 1000) {
currentDateGenerated = now;
currentDate = format.format(new Date(now));
}
}
}
return currentDate;
}
/**
* Parses date with given formatters.
*
* @param value The string to parse
* @param formats Array of formats to use
*
* @return Parsed date (or <code>null</code> if no formatter mached)
*/
private static final Long internalParseDate(String value,
DateFormat[] formats) {
Date date = null;
for (int i = 0; (date == null) && (i < formats.length); i++) {
try {
date = formats[i].parse(value);
} catch (ParseException e) {
;
}
}
if (date == null) {
return null;
}
return new Long(date.getTime());
}
/**
* Tries to parse the given date as an HTTP date. If local format list is
* not <code>null</code>, it's used instead.
*
* @param value The string to parse
* @param threadLocalformats Array of formats to use for parsing. If
* <code>null</code>, HTTP formats are used.
*
* @return Parsed date (or -1 if error occured)
*/
public static final long parseDate(String value,
DateFormat[] threadLocalformats) {
Long cachedDate = null;
try {
cachedDate = (Long) parseCache.get(value);
} catch (Exception e) {}
if (cachedDate != null) {
return cachedDate.longValue();
}
Long date = null;
if (threadLocalformats != null) {
date = internalParseDate(value, threadLocalformats);
synchronized (parseCache) {
updateCache(parseCache, value, date);
}
} else {
synchronized (parseCache) {
date = internalParseDate(value, formats);
updateCache(parseCache, value, date);
}
}
if (date == null) {
return (-1L);
} else {
return date.longValue();
}
}
/**
* Updates cache.
*
* @param cache Cache to be updated
* @param key Key to be updated
* @param value New value
*/
private static final void updateCache(HashMap cache, Object key,
Object value) {
if (value == null) {
return;
}
if (cache.size() > 1000) {
cache.clear();
}
cache.put(key, value);
}
}

View File

@ -0,0 +1,362 @@
/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.acegisecurity.ui.savedrequest;
import org.acegisecurity.util.PortResolver;
import org.acegisecurity.util.UrlUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
/**
* Represents central information from a <code>HttpServletRequest</code>.
*
* <p>
* This class is used by {@link org.acegisecurity.ui.AbstractProcessingFilter}
* and {@link org.acegisecurity.wrapper.SavedRequestAwareWrapper} to reproduce
* the request after successful authentication. An instance of this class is
* stored at the time of an authentication exception by {@link
* org.acegisecurity.ui.ExceptionTranslationFilter}.
* </p>
*
* <p>
* <em>IMPLEMENTATION NOTE</em>: It is assumed that this object is accessed
* only from the context of a single thread, so no synchronization around
* internal collection classes is performed.
* </p>
*
* <p>
* This class is based on code in Apache Tomcat.
* </p>
*
* @author Craig McClanahan
* @author Andrey Grebnev
* @author Ben Alex
* @version $Id$
*/
public class SavedRequest {
//~ Static fields/initializers =============================================
protected static final Log logger = LogFactory.getLog(SavedRequest.class);
//~ Instance fields ========================================================
private ArrayList cookies = new ArrayList();
private ArrayList locales = new ArrayList();
private HashMap headers = new HashMap();
private HashMap parameters = new HashMap();
private String contextPath;
private String method;
private String pathInfo;
private String queryString;
private String requestURI;
private String requestURL;
private String scheme;
private String serverName;
private String servletPath;
private int serverPort;
//~ Constructors ===========================================================
public SavedRequest(HttpServletRequest request, PortResolver portResolver) {
Assert.notNull(request, "Request required");
Assert.notNull(portResolver, "PortResolver required");
// Cookies
Cookie[] cookies = request.getCookies();
if (cookies != null) {
for (int i = 0; i < cookies.length; i++) {
this.addCookie(cookies[i]);
}
}
// Headers
Enumeration names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = (String) names.nextElement();
Enumeration values = request.getHeaders(name);
while (values.hasMoreElements()) {
String value = (String) values.nextElement();
this.addHeader(name, value);
}
}
// Locales
Enumeration locales = request.getLocales();
while (locales.hasMoreElements()) {
Locale locale = (Locale) locales.nextElement();
this.addLocale(locale);
}
// Parameters
Map parameters = request.getParameterMap();
Iterator paramNames = parameters.keySet().iterator();
while (paramNames.hasNext()) {
String paramName = (String) paramNames.next();
String[] paramValues = (String[]) parameters.get(paramName);
this.addParameter(paramName, paramValues);
}
// Primitives
this.method = request.getMethod();
this.pathInfo = request.getPathInfo();
this.queryString = request.getQueryString();
this.requestURI = request.getRequestURI();
this.serverPort = portResolver.getServerPort(request);
this.requestURL = request.getRequestURL().toString();
this.scheme = request.getScheme();
this.serverName = request.getServerName();
this.contextPath = request.getContextPath();
this.servletPath = request.getServletPath();
}
//~ Methods ================================================================
private void addCookie(Cookie cookie) {
cookies.add(cookie);
}
private void addHeader(String name, String value) {
ArrayList values = (ArrayList) headers.get(name);
if (values == null) {
values = new ArrayList();
headers.put(name, values);
}
values.add(value);
}
private void addLocale(Locale locale) {
locales.add(locale);
}
private void addParameter(String name, String[] values) {
parameters.put(name, values);
}
/**
* Determines if the current request matches the <code>SavedRequest</code>.
* All URL arguments are considered, but <em>not</em> method (POST/GET),
* cookies, locales, headers or parameters.
*
* @param request DOCUMENT ME!
* @param portResolver DOCUMENT ME!
*
* @return DOCUMENT ME!
*/
public boolean doesRequestMatch(HttpServletRequest request,
PortResolver portResolver) {
Assert.notNull(request, "Request required");
Assert.notNull(portResolver, "PortResolver required");
if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) {
return false;
}
if (!propertyEquals("queryString", this.queryString,
request.getQueryString())) {
return false;
}
if (!propertyEquals("requestURI", this.requestURI,
request.getRequestURI())) {
return false;
}
if (!propertyEquals("serverPort", new Integer(this.serverPort),
new Integer(portResolver.getServerPort(request)))) {
return false;
}
if (!propertyEquals("requestURL", this.requestURL,
request.getRequestURL().toString())) {
return false;
}
if (!propertyEquals("scheme", this.scheme, request.getScheme())) {
return false;
}
if (!propertyEquals("serverName", this.serverName,
request.getServerName())) {
return false;
}
if (!propertyEquals("contextPath", this.contextPath,
request.getContextPath())) {
return false;
}
if (!propertyEquals("servletPath", this.servletPath,
request.getServletPath())) {
return false;
}
return true;
}
public String getContextPath() {
return contextPath;
}
public List getCookies() {
return cookies;
}
/**
* Indicates the URL that the user agent used for this request.
*
* @return the full URL of this request
*/
public String getFullRequestUrl() {
return UrlUtils.getFullRequestUrl(this);
}
public Iterator getHeaderNames() {
return (headers.keySet().iterator());
}
public Iterator getHeaderValues(String name) {
ArrayList values = (ArrayList) headers.get(name);
if (values == null) {
return ((new ArrayList()).iterator());
} else {
return (values.iterator());
}
}
public Iterator getLocales() {
return (locales.iterator());
}
public String getMethod() {
return (this.method);
}
public Map getParameterMap() {
return parameters;
}
public Iterator getParameterNames() {
return (parameters.keySet().iterator());
}
public String[] getParameterValues(String name) {
return ((String[]) parameters.get(name));
}
public String getPathInfo() {
return pathInfo;
}
public String getQueryString() {
return (this.queryString);
}
public String getRequestURI() {
return (this.requestURI);
}
public String getRequestURL() {
return requestURL;
}
/**
* Obtains the web application-specific fragment of the URL.
*
* @return the URL, excluding any server name, context path or servlet path
*/
public String getRequestUrl() {
return UrlUtils.getRequestUrl(this);
}
public String getScheme() {
return scheme;
}
public String getServerName() {
return serverName;
}
public int getServerPort() {
return serverPort;
}
public String getServletPath() {
return servletPath;
}
private boolean propertyEquals(String log, Object arg1, Object arg2) {
if ((arg1 == null) && (arg2 == null)) {
if (logger.isDebugEnabled()) {
logger.debug(log + ": both null (property equals)");
}
return true;
}
if (((arg1 == null) && (arg2 != null))
|| ((arg1 != null) && (arg2 == null))) {
if (logger.isDebugEnabled()) {
logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
+ " (property not equals)");
}
return false;
}
if (arg1.equals(arg2)) {
if (logger.isDebugEnabled()) {
logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
+ " (property equals)");
}
return true;
} else {
if (logger.isDebugEnabled()) {
logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
+ " (property not equals)");
}
return false;
}
}
public String toString() {
return "SavedRequest[" + getFullRequestUrl() + "]";
}
}

View File

@ -0,0 +1,6 @@
<html>
<body>
Stores a <code>HttpServletRequest</code> so that it can subsequently be emulated by the
<code>SavedRequestAwareWrapper</code>.
</body>
</html>

View File

@ -0,0 +1,130 @@
/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.acegisecurity.util;
import org.acegisecurity.intercept.web.FilterInvocation;
import org.acegisecurity.ui.savedrequest.SavedRequest;
import javax.servlet.http.HttpServletRequest;
/**
* Provides static methods for composing URLs.
*
* <p>
* Placed into a separate class for visibility, so that changes to URL
* formatting conventions will affect all users.
* </p>
*
* @author Ben Alex
* @version $Id$
*/
public class UrlUtils {
//~ Methods ================================================================
/**
* Obtains the full URL the client used to make the request.
*
* <p>
* Note that the server port will not be shown if it is the default server
* port for HTTP or HTTPS (ie 80 and 443 respectively).
* </p>
*
* @param scheme DOCUMENT ME!
* @param serverName DOCUMENT ME!
* @param serverPort DOCUMENT ME!
* @param contextPath DOCUMENT ME!
* @param requestUrl DOCUMENT ME!
* @param servletPath DOCUMENT ME!
* @param requestURI DOCUMENT ME!
* @param pathInfo DOCUMENT ME!
* @param queryString DOCUMENT ME!
*
* @return the full URL
*/
private static String buildFullRequestUrl(String scheme, String serverName,
int serverPort, String contextPath, String requestUrl,
String servletPath, String requestURI, String pathInfo,
String queryString) {
boolean includePort = true;
if ("http".equals(scheme.toLowerCase()) && (serverPort == 80)) {
includePort = false;
}
if ("https".equals(scheme.toLowerCase()) && (serverPort == 443)) {
includePort = false;
}
return scheme + "://" + serverName
+ ((includePort) ? (":" + serverPort) : "") + contextPath
+ buildRequestUrl(servletPath, requestURI, contextPath, pathInfo,
queryString);
}
/**
* Obtains the web application-specific fragment of the URL.
*
* @param servletPath DOCUMENT ME!
* @param requestURI DOCUMENT ME!
* @param contextPath DOCUMENT ME!
* @param pathInfo DOCUMENT ME!
* @param queryString DOCUMENT ME!
*
* @return the URL, excluding any server name, context path or servlet path
*/
private static String buildRequestUrl(String servletPath,
String requestURI, String contextPath, String pathInfo,
String queryString) {
String uri = servletPath;
if (uri == null) {
uri = requestURI;
uri = uri.substring(contextPath.length());
}
return uri + ((pathInfo == null) ? "" : pathInfo)
+ ((queryString == null) ? "" : ("?" + queryString));
}
public static String getFullRequestUrl(FilterInvocation fi) {
HttpServletRequest r = fi.getHttpRequest();
return buildFullRequestUrl(r.getScheme(), r.getServerName(),
r.getServerPort(), r.getContextPath(),
r.getRequestURL().toString(), r.getServletPath(),
r.getRequestURI(), r.getPathInfo(), r.getQueryString());
}
public static String getFullRequestUrl(SavedRequest sr) {
return buildFullRequestUrl(sr.getScheme(), sr.getServerName(),
sr.getServerPort(), sr.getContextPath(), sr.getRequestURL(),
sr.getServletPath(), sr.getRequestURI(), sr.getPathInfo(),
sr.getQueryString());
}
public static String getRequestUrl(FilterInvocation fi) {
HttpServletRequest r = fi.getHttpRequest();
return buildRequestUrl(r.getServletPath(), r.getRequestURI(),
r.getContextPath(), r.getPathInfo(), r.getQueryString());
}
public static String getRequestUrl(SavedRequest sr) {
return buildRequestUrl(sr.getServletPath(), sr.getRequestURI(),
sr.getContextPath(), sr.getPathInfo(), sr.getQueryString());
}
}

View File

@ -0,0 +1,409 @@
/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.acegisecurity.wrapper;
import org.acegisecurity.ui.AbstractProcessingFilter;
import org.acegisecurity.ui.savedrequest.Enumerator;
import org.acegisecurity.ui.savedrequest.FastHttpDateFormat;
import org.acegisecurity.ui.savedrequest.SavedRequest;
import org.acegisecurity.util.PortResolver;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TimeZone;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
/**
* Provides request parameters, headers and cookies from either an original
* request or a saved request.
*
* <p>
* Note that not all request parameters in the original request are emulated by
* this wrapper. Nevertheless, the important data from the original request is
* emulated and this should prove adequate for most purposes (in particular
* standard HTTP GET and POST operations).
* </p>
*
* <p>
* Added into a request by {@link
* org.acegisecurity.wrapper.SecurityContextHolderAwareRequestFilter}.
* </p>
*
* @author Andrey Grebnev
* @author Ben Alex
* @version $Id$
*/
public class SavedRequestAwareWrapper
extends SecurityContextHolderAwareRequestWrapper {
//~ Static fields/initializers =============================================
protected static final Log logger = LogFactory.getLog(SavedRequestAwareWrapper.class);
protected static final TimeZone GMT_ZONE = TimeZone.getTimeZone("GMT");
/** The default Locale if none are specified. */
protected static Locale defaultLocale = Locale.getDefault();
//~ Instance fields ========================================================
protected SavedRequest savedRequest = null;
/**
* The set of SimpleDateFormat formats to use in getDateHeader(). Notice
* that because SimpleDateFormat is not thread-safe, we can't declare
* formats[] as a static variable.
*/
protected SimpleDateFormat[] formats = new SimpleDateFormat[3];
//~ Constructors ===========================================================
public SavedRequestAwareWrapper(HttpServletRequest request,
PortResolver portResolver) {
super(request);
HttpSession session = request.getSession(false);
if (session == null) {
if (logger.isDebugEnabled()) {
logger.debug(
"Wrapper not replaced; no session available for SavedRequest extraction");
}
return;
}
SavedRequest saved = (SavedRequest) session.getAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY);
if ((saved != null) && saved.doesRequestMatch(request, portResolver)) {
if (logger.isDebugEnabled()) {
logger.debug("Wrapper replaced; SavedRequest was: " + saved);
}
savedRequest = saved;
session.removeAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY);
formats[0] = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz",
Locale.US);
formats[1] = new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz",
Locale.US);
formats[2] = new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy",
Locale.US);
formats[0].setTimeZone(GMT_ZONE);
formats[1].setTimeZone(GMT_ZONE);
formats[2].setTimeZone(GMT_ZONE);
} else {
if (logger.isDebugEnabled()) {
logger.debug("Wrapper not replaced; SavedRequest was: " + saved);
}
}
}
//~ Methods ================================================================
/**
* The default behavior of this method is to return getCookies() on the
* wrapped request object.
*
* @return DOCUMENT ME!
*/
public Cookie[] getCookies() {
if (savedRequest == null) {
return super.getCookies();
} else {
List cookies = savedRequest.getCookies();
return (Cookie[]) cookies.toArray(new Cookie[cookies.size()]);
}
}
/**
* The default behavior of this method is to return getDateHeader(String
* name) on the wrapped request object.
*
* @param name DOCUMENT ME!
*
* @return DOCUMENT ME!
*
* @throws IllegalArgumentException DOCUMENT ME!
*/
public long getDateHeader(String name) {
if (savedRequest == null) {
return super.getDateHeader(name);
} else {
String value = getHeader(name);
if (value == null) {
return (-1L);
}
// Attempt to convert the date header in a variety of formats
long result = FastHttpDateFormat.parseDate(value, formats);
if (result != (-1L)) {
return result;
}
throw new IllegalArgumentException(value);
}
}
/**
* The default behavior of this method is to return getHeader(String name)
* on the wrapped request object.
*
* @param name DOCUMENT ME!
*
* @return DOCUMENT ME!
*/
public String getHeader(String name) {
if (savedRequest == null) {
return super.getHeader(name);
} else {
String header = null;
Iterator iterator = savedRequest.getHeaderValues(name);
while (iterator.hasNext()) {
header = (String) iterator.next();
break;
}
return header;
}
}
/**
* The default behavior of this method is to return getHeaderNames() on the
* wrapped request object.
*
* @return DOCUMENT ME!
*/
public Enumeration getHeaderNames() {
if (savedRequest == null) {
return super.getHeaderNames();
} else {
return new Enumerator(savedRequest.getHeaderNames());
}
}
/**
* The default behavior of this method is to return getHeaders(String name)
* on the wrapped request object.
*
* @param name DOCUMENT ME!
*
* @return DOCUMENT ME!
*/
public Enumeration getHeaders(String name) {
if (savedRequest == null) {
return super.getHeaders(name);
} else {
return new Enumerator(savedRequest.getHeaderValues(name));
}
}
/**
* The default behavior of this method is to return getIntHeader(String
* name) on the wrapped request object.
*
* @param name DOCUMENT ME!
*
* @return DOCUMENT ME!
*/
public int getIntHeader(String name) {
if (savedRequest == null) {
return super.getIntHeader(name);
} else {
String value = getHeader(name);
if (value == null) {
return (-1);
} else {
return (Integer.parseInt(value));
}
}
}
/**
* The default behavior of this method is to return getLocale() on the
* wrapped request object.
*
* @return DOCUMENT ME!
*/
public Locale getLocale() {
if (savedRequest == null) {
return super.getLocale();
} else {
Locale locale = null;
Iterator iterator = savedRequest.getLocales();
while (iterator.hasNext()) {
locale = (Locale) iterator.next();
break;
}
if (locale == null) {
return defaultLocale;
} else {
return locale;
}
}
}
/**
* The default behavior of this method is to return getLocales() on the
* wrapped request object.
*
* @return DOCUMENT ME!
*/
public Enumeration getLocales() {
if (savedRequest == null) {
return super.getLocales();
} else {
Iterator iterator = savedRequest.getLocales();
if (iterator.hasNext()) {
return new Enumerator(iterator);
} else {
ArrayList results = new ArrayList();
results.add(defaultLocale);
return new Enumerator(results.iterator());
}
}
}
/**
* The default behavior of this method is to return getMethod() on the
* wrapped request object.
*
* @return DOCUMENT ME!
*/
public String getMethod() {
if (savedRequest == null) {
return super.getMethod();
} else {
return savedRequest.getMethod();
}
}
/**
* The default behavior of this method is to return getParameter(String
* name) on the wrapped request object.
*
* @param name DOCUMENT ME!
*
* @return DOCUMENT ME!
*/
public String getParameter(String name) {
/*
if (savedRequest == null) {
return super.getParameter(name);
} else {
String value = null;
String[] values = savedRequest.getParameterValues(name);
if (values == null)
return null;
for (int i = 0; i < values.length; i++) {
value = values[i];
break;
}
return value;
}
*/
//we do not get value from super.getParameter because there is a bug in Jetty servlet-container
String value = null;
String[] values = null;
if (savedRequest == null) {
values = super.getParameterValues(name);
} else {
values = savedRequest.getParameterValues(name);
}
if (values == null) {
return null;
}
for (int i = 0; i < values.length; i++) {
value = values[i];
break;
}
return value;
}
/**
* The default behavior of this method is to return getParameterMap() on
* the wrapped request object.
*
* @return DOCUMENT ME!
*/
public Map getParameterMap() {
if (savedRequest == null) {
return super.getParameterMap();
} else {
return savedRequest.getParameterMap();
}
}
/**
* The default behavior of this method is to return getParameterNames() on
* the wrapped request object.
*
* @return DOCUMENT ME!
*/
public Enumeration getParameterNames() {
if (savedRequest == null) {
return super.getParameterNames();
} else {
return new Enumerator(savedRequest.getParameterNames());
}
}
/**
* The default behavior of this method is to return
* getParameterValues(String name) on the wrapped request object.
*
* @param name DOCUMENT ME!
*
* @return DOCUMENT ME!
*/
public String[] getParameterValues(String name) {
if (savedRequest == null) {
return super.getParameterValues(name);
} else {
return savedRequest.getParameterValues(name);
}
}
}

View File

@ -1,4 +1,4 @@
/* Copyright 2004, 2005 Acegi Technology Pty Limited /* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,8 +15,16 @@
package org.acegisecurity.wrapper; package org.acegisecurity.wrapper;
import org.acegisecurity.util.PortResolver;
import org.acegisecurity.util.PortResolverImpl;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Constructor;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.FilterConfig; import javax.servlet.FilterConfig;
@ -27,13 +35,38 @@ import javax.servlet.http.HttpServletRequest;
/** /**
* A <code>Filter</code> which populates the <code>ServletRequest</code> with * A <code>Filter</code> which populates the <code>ServletRequest</code> with a
* an {@link SecurityContextHolderAwareRequestWrapper}. * new request wrapper.
*
* <p>
* Several request wrappers are included with the framework. The simplest
* version is {@link SecurityContextHolderAwareRequestWrapper}. A more complex
* and powerful request wrapper is {@link
* org.acegisecurity.wrapper.SavedRequestAwareWrapper}. The latter is also the
* default.
* </p>
*
* <p>
* To modify the wrapper used, call {@link #setWrapperClass(Class)}.
* </p>
*
* <p>
* Any request wrapper configured for instantiation by this class must provide
* a public constructor that accepts two arguments, being a
* <code>HttpServletRequest</code> and a <code>PortResolver</code>.
* </p>
* *
* @author Orlando Garcia Carmona * @author Orlando Garcia Carmona
* @author Ben Alex
* @version $Id$ * @version $Id$
*/ */
public class SecurityContextHolderAwareRequestFilter implements Filter { public class SecurityContextHolderAwareRequestFilter implements Filter {
//~ Instance fields ========================================================
private Class wrapperClass = SavedRequestAwareWrapper.class;
private Constructor constructor;
private PortResolver portResolver = new PortResolverImpl();
//~ Methods ================================================================ //~ Methods ================================================================
public void destroy() {} public void destroy() {}
@ -43,12 +76,36 @@ public class SecurityContextHolderAwareRequestFilter implements Filter {
throws IOException, ServletException { throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletRequest request = (HttpServletRequest) servletRequest;
if (!(request instanceof SecurityContextHolderAwareRequestWrapper)) { if (!wrapperClass.isAssignableFrom(request.getClass())) {
request = new SecurityContextHolderAwareRequestWrapper(request); if (constructor == null) {
try {
constructor = wrapperClass.getConstructor(new Class[] {HttpServletRequest.class, PortResolver.class});
} catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex);
}
}
try {
request = (HttpServletRequest) constructor.newInstance(new Object[] {request, portResolver});
} catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex);
}
} }
filterChain.doFilter(request, servletResponse); filterChain.doFilter(request, servletResponse);
} }
public void init(FilterConfig filterConfig) throws ServletException {} public void init(FilterConfig filterConfig) throws ServletException {}
public void setPortResolver(PortResolver portResolver) {
Assert.notNull(portResolver, "PortResolver required");
this.portResolver = portResolver;
}
public void setWrapperClass(Class wrapperClass) {
Assert.notNull(wrapperClass, "WrapperClass required");
Assert.isTrue(HttpServletRequest.class.isAssignableFrom(wrapperClass),
"Wrapper must be a HttpServletRequest");
this.wrapperClass = wrapperClass;
}
} }

View File

@ -1,4 +1,4 @@
/* Copyright 2004, 2005 Acegi Technology Pty Limited /* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,12 +17,14 @@ package org.acegisecurity.intercept.web;
import org.acegisecurity.MockFilterChain; import org.acegisecurity.MockFilterChain;
import javax.servlet.ServletRequest; import org.jmock.MockObjectTestCase;
import javax.servlet.ServletResponse;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.jmock.MockObjectTestCase;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
/** /**
* Tests {@link FilterInvocation}. * Tests {@link FilterInvocation}.
@ -44,14 +46,14 @@ public class FilterInvocationTests extends MockObjectTestCase {
//~ Methods ================================================================ //~ Methods ================================================================
public final void setUp() throws Exception {
super.setUp();
}
public static void main(String[] args) { public static void main(String[] args) {
junit.textui.TestRunner.run(FilterInvocationTests.class); junit.textui.TestRunner.run(FilterInvocationTests.class);
} }
public final void setUp() throws Exception {
super.setUp();
}
public void testGettersAndStringMethods() { public void testGettersAndStringMethods() {
MockHttpServletRequest request = new MockHttpServletRequest(null, null); MockHttpServletRequest request = new MockHttpServletRequest(null, null);
request.setServletPath("/HelloWorld"); request.setServletPath("/HelloWorld");
@ -73,7 +75,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
assertEquals("/HelloWorld/some/more/segments.html", fi.getRequestUrl()); assertEquals("/HelloWorld/some/more/segments.html", fi.getRequestUrl());
assertEquals("FilterInvocation: URL: /HelloWorld/some/more/segments.html", assertEquals("FilterInvocation: URL: /HelloWorld/some/more/segments.html",
fi.toString()); fi.toString());
assertEquals("http://www.example.com:80/mycontext/HelloWorld/some/more/segments.html", assertEquals("http://www.example.com/mycontext/HelloWorld/some/more/segments.html",
fi.getFullRequestUrl()); fi.getFullRequestUrl());
} }
@ -81,7 +83,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
Class clazz = FilterInvocation.class; Class clazz = FilterInvocation.class;
try { try {
clazz.getDeclaredConstructor((Class[])null); clazz.getDeclaredConstructor((Class[]) null);
fail("Should have thrown NoSuchMethodException"); fail("Should have thrown NoSuchMethodException");
} catch (NoSuchMethodException expected) { } catch (NoSuchMethodException expected) {
assertTrue(true); assertTrue(true);
@ -125,7 +127,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
} }
public void testRejectsServletRequestWhichIsNotHttpServletRequest() { public void testRejectsServletRequestWhichIsNotHttpServletRequest() {
ServletRequest request = (ServletRequest)newDummy(ServletRequest.class); ServletRequest request = (ServletRequest) newDummy(ServletRequest.class);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain chain = new MockFilterChain(); MockFilterChain chain = new MockFilterChain();
@ -167,7 +169,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
FilterInvocation fi = new FilterInvocation(request, response, chain); FilterInvocation fi = new FilterInvocation(request, response, chain);
assertEquals("/HelloWorld?foo=bar", fi.getRequestUrl()); assertEquals("/HelloWorld?foo=bar", fi.getRequestUrl());
assertEquals("FilterInvocation: URL: /HelloWorld?foo=bar", fi.toString()); assertEquals("FilterInvocation: URL: /HelloWorld?foo=bar", fi.toString());
assertEquals("http://www.example.com:80/mycontext/HelloWorld?foo=bar", assertEquals("http://www.example.com/mycontext/HelloWorld?foo=bar",
fi.getFullRequestUrl()); fi.getFullRequestUrl());
} }
@ -185,7 +187,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
FilterInvocation fi = new FilterInvocation(request, response, chain); FilterInvocation fi = new FilterInvocation(request, response, chain);
assertEquals("/HelloWorld", fi.getRequestUrl()); assertEquals("/HelloWorld", fi.getRequestUrl());
assertEquals("FilterInvocation: URL: /HelloWorld", fi.toString()); assertEquals("FilterInvocation: URL: /HelloWorld", fi.toString());
assertEquals("http://www.example.com:80/mycontext/HelloWorld", assertEquals("http://www.example.com/mycontext/HelloWorld",
fi.getFullRequestUrl()); fi.getFullRequestUrl());
} }
} }

View File

@ -30,6 +30,9 @@ import org.acegisecurity.context.SecurityContextHolder;
import org.acegisecurity.providers.UsernamePasswordAuthenticationToken; import org.acegisecurity.providers.UsernamePasswordAuthenticationToken;
import org.acegisecurity.ui.rememberme.TokenBasedRememberMeServices; import org.acegisecurity.ui.rememberme.TokenBasedRememberMeServices;
import org.acegisecurity.ui.savedrequest.SavedRequest;
import org.acegisecurity.util.PortResolverImpl;
import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
@ -91,6 +94,16 @@ public class AbstractProcessingFilterTests extends TestCase {
junit.textui.TestRunner.run(AbstractProcessingFilterTests.class); junit.textui.TestRunner.run(AbstractProcessingFilterTests.class);
} }
private SavedRequest makeSavedRequestForUrl() {
MockHttpServletRequest request = createMockRequest();
request.setServletPath("/some_protected_file.html");
request.setScheme("http");
request.setServerName("www.example.com");
request.setRequestURI("/mycontext/some_protected_file.html");
return new SavedRequest(request, new PortResolverImpl());
}
protected void setUp() throws Exception { protected void setUp() throws Exception {
super.setUp(); super.setUp();
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
@ -399,8 +412,8 @@ public class AbstractProcessingFilterTests extends TestCase {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = createMockRequest(); MockHttpServletRequest request = createMockRequest();
request.getSession() request.getSession()
.setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY, .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY,
"/my-destination"); makeSavedRequestForUrl());
// Setup our filter configuration // Setup our filter configuration
MockFilterConfig config = new MockFilterConfig(null); MockFilterConfig config = new MockFilterConfig(null);
@ -429,8 +442,8 @@ public class AbstractProcessingFilterTests extends TestCase {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = createMockRequest(); MockHttpServletRequest request = createMockRequest();
request.getSession() request.getSession()
.setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY, .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY,
"/my-destination"); makeSavedRequestForUrl());
// Setup our filter configuration // Setup our filter configuration
MockFilterConfig config = new MockFilterConfig(null); MockFilterConfig config = new MockFilterConfig(null);
@ -446,7 +459,8 @@ public class AbstractProcessingFilterTests extends TestCase {
// Test // Test
executeFilterInContainerSimulator(config, filter, request, response, executeFilterInContainerSimulator(config, filter, request, response,
chain); chain);
assertEquals("/my-destination", response.getRedirectedUrl()); assertEquals(makeSavedRequestForUrl().getFullRequestUrl(),
response.getRedirectedUrl());
assertNotNull(SecurityContextHolder.getContext().getAuthentication()); assertNotNull(SecurityContextHolder.getContext().getAuthentication());
} }

View File

@ -28,8 +28,6 @@ import org.acegisecurity.context.SecurityContextHolder;
import org.acegisecurity.providers.anonymous.AnonymousAuthenticationToken; import org.acegisecurity.providers.anonymous.AnonymousAuthenticationToken;
import org.acegisecurity.ui.webapp.AuthenticationProcessingFilter;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
@ -101,8 +99,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
filter.doFilter(request, response, chain); filter.doFilter(request, response, chain);
assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
assertEquals("http://www.example.com/mycontext/secure/page.html", assertEquals("http://www.example.com/mycontext/secure/page.html",
request.getSession() AbstractProcessingFilter.obtainFullRequestUrl(request));
.getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY));
} }
public void testAccessDeniedWhenNonAnonymous() throws Exception { public void testAccessDeniedWhenNonAnonymous() throws Exception {
@ -192,8 +189,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
filter.doFilter(request, response, chain); filter.doFilter(request, response, chain);
assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
assertEquals("http://www.example.com/mycontext/secure/page.html", assertEquals("http://www.example.com/mycontext/secure/page.html",
request.getSession() AbstractProcessingFilter.obtainFullRequestUrl(request));
.getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY));
} }
public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException() public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException()
@ -221,8 +217,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
filter.doFilter(request, response, chain); filter.doFilter(request, response, chain);
assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
assertEquals("http://www.example.com:8080/mycontext/secure/page.html", assertEquals("http://www.example.com:8080/mycontext/secure/page.html",
request.getSession() AbstractProcessingFilter.obtainFullRequestUrl(request));
.getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY));
} }
public void testStartupDetectsMissingAuthenticationEntryPoint() public void testStartupDetectsMissingAuthenticationEntryPoint()

View File

@ -1,4 +1,4 @@
/* Copyright 2004, 2005 Acegi Technology Pty Limited /* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,6 +19,8 @@ import junit.framework.TestCase;
import org.acegisecurity.MockFilterConfig; import org.acegisecurity.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import java.io.IOException; import java.io.IOException;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
@ -26,8 +28,6 @@ import javax.servlet.ServletException;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
import org.springframework.mock.web.MockHttpServletRequest;
/** /**
* Tests {@link SecurityContextHolderAwareRequestFilter}. * Tests {@link SecurityContextHolderAwareRequestFilter}.
@ -48,23 +48,23 @@ public class SecurityContextHolderAwareRequestFilterTests extends TestCase {
//~ Methods ================================================================ //~ Methods ================================================================
public final void setUp() throws Exception {
super.setUp();
}
public static void main(String[] args) { public static void main(String[] args) {
junit.textui.TestRunner.run(SecurityContextHolderAwareRequestFilterTests.class); junit.textui.TestRunner.run(SecurityContextHolderAwareRequestFilterTests.class);
} }
public final void setUp() throws Exception {
super.setUp();
}
public void testCorrectOperation() throws Exception { public void testCorrectOperation() throws Exception {
SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter(); SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter();
filter.init(new MockFilterConfig()); filter.init(new MockFilterConfig());
filter.doFilter(new MockHttpServletRequest(null, null), null, filter.doFilter(new MockHttpServletRequest(null, null), null,
new MockFilterChain(SecurityContextHolderAwareRequestWrapper.class)); new MockFilterChain(SavedRequestAwareWrapper.class));
// Now re-execute the filter, ensuring our replacement wrapper is still used // Now re-execute the filter, ensuring our replacement wrapper is still used
filter.doFilter(new MockHttpServletRequest(null, null), null, filter.doFilter(new MockHttpServletRequest(null, null), null,
new MockFilterChain(SecurityContextHolderAwareRequestWrapper.class)); new MockFilterChain(SavedRequestAwareWrapper.class));
filter.destroy(); filter.destroy();
} }

View File

@ -21,7 +21,7 @@
<value> <value>
CONVERT_URL_TO_LOWERCASE_BEFORE_COMPARISON CONVERT_URL_TO_LOWERCASE_BEFORE_COMPARISON
PATTERN_TYPE_APACHE_ANT PATTERN_TYPE_APACHE_ANT
/**=httpSessionContextIntegrationFilter,logoutFilter,authenticationProcessingFilter,basicProcessingFilter,rememberMeProcessingFilter,anonymousProcessingFilter,switchUserProcessingFilter,exceptionTranslationFilter,filterInvocationInterceptor /**=httpSessionContextIntegrationFilter,logoutFilter,authenticationProcessingFilter,basicProcessingFilter,securityContextHolderAwareRequestFilter,rememberMeProcessingFilter,anonymousProcessingFilter,switchUserProcessingFilter,exceptionTranslationFilter,filterInvocationInterceptor
</value> </value>
</property> </property>
</bean> </bean>
@ -113,6 +113,8 @@
</constructor-arg> </constructor-arg>
</bean> </bean>
<bean id="securityContextHolderAwareRequestFilter" class="org.acegisecurity.wrapper.SecurityContextHolderAwareRequestFilter"/>
<!-- ===================== HTTP CHANNEL REQUIREMENTS ==================== --> <!-- ===================== HTTP CHANNEL REQUIREMENTS ==================== -->
<!-- You will need to uncomment the "Acegi Channel Processing Filter" <!-- You will need to uncomment the "Acegi Channel Processing Filter"