From d125569bd6ebd3a2951c06f44b40cf3adf47887a Mon Sep 17 00:00:00 2001 From: Ben Alex Date: Fri, 28 Apr 2006 05:05:35 +0000 Subject: [PATCH] SEC-29: Save POST parameters on AuthenticationEntryPoint redirect. --- .../intercept/web/FilterInvocation.java | 27 +- .../ChannelProcessingFilter.java | 58 +-- .../ui/AbstractProcessingFilter.java | 31 +- .../ui/ExceptionTranslationFilter.java | 30 +- .../ui/savedrequest/Enumerator.java | 152 +++++++ .../ui/savedrequest/FastHttpDateFormat.java | 234 ++++++++++ .../ui/savedrequest/SavedRequest.java | 362 ++++++++++++++++ .../ui/savedrequest/package.html | 6 + .../java/org/acegisecurity/util/UrlUtils.java | 130 ++++++ .../wrapper/SavedRequestAwareWrapper.java | 409 ++++++++++++++++++ ...curityContextHolderAwareRequestFilter.java | 67 ++- .../intercept/web/FilterInvocationTests.java | 28 +- .../ui/AbstractProcessingFilterTests.java | 24 +- .../ui/ExceptionTranslationFilterTests.java | 11 +- ...yContextHolderAwareRequestFilterTests.java | 18 +- .../applicationContext-acegi-security.xml | 4 +- 16 files changed, 1475 insertions(+), 116 deletions(-) create mode 100644 core/src/main/java/org/acegisecurity/ui/savedrequest/Enumerator.java create mode 100644 core/src/main/java/org/acegisecurity/ui/savedrequest/FastHttpDateFormat.java create mode 100644 core/src/main/java/org/acegisecurity/ui/savedrequest/SavedRequest.java create mode 100644 core/src/main/java/org/acegisecurity/ui/savedrequest/package.html create mode 100644 core/src/main/java/org/acegisecurity/util/UrlUtils.java create mode 100644 core/src/main/java/org/acegisecurity/wrapper/SavedRequestAwareWrapper.java diff --git a/core/src/main/java/org/acegisecurity/intercept/web/FilterInvocation.java b/core/src/main/java/org/acegisecurity/intercept/web/FilterInvocation.java index e88af85bd5..4334017a42 100644 --- a/core/src/main/java/org/acegisecurity/intercept/web/FilterInvocation.java +++ b/core/src/main/java/org/acegisecurity/intercept/web/FilterInvocation.java @@ -1,4 +1,4 @@ -/* Copyright 2004, 2005 Acegi Technology Pty Limited +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ package org.acegisecurity.intercept.web; +import org.acegisecurity.util.UrlUtils; + import javax.servlet.FilterChain; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; @@ -88,10 +90,7 @@ public class FilterInvocation { * @return the full URL of this request */ public String getFullRequestUrl() { - return getHttpRequest().getScheme() + "://" - + getHttpRequest().getServerName() + ":" - + getHttpRequest().getServerPort() + getHttpRequest().getContextPath() - + getRequestUrl(); + return UrlUtils.getFullRequestUrl(this); } public HttpServletRequest getHttpRequest() { @@ -106,19 +105,13 @@ public class FilterInvocation { return request; } + /** + * Obtains the web application-specific fragment of the URL. + * + * @return the URL, excluding any server name, context path or servlet path + */ public String getRequestUrl() { - String pathInfo = getHttpRequest().getPathInfo(); - String queryString = getHttpRequest().getQueryString(); - - String uri = getHttpRequest().getServletPath(); - - if (uri == null) { - uri = getHttpRequest().getRequestURI(); - uri = uri.substring(getHttpRequest().getContextPath().length()); - } - - return uri + ((pathInfo == null) ? "" : pathInfo) - + ((queryString == null) ? "" : ("?" + queryString)); + return UrlUtils.getRequestUrl(this); } public ServletResponse getResponse() { diff --git a/core/src/main/java/org/acegisecurity/securechannel/ChannelProcessingFilter.java b/core/src/main/java/org/acegisecurity/securechannel/ChannelProcessingFilter.java index 4ce8bf79a6..199ce310fe 100644 --- a/core/src/main/java/org/acegisecurity/securechannel/ChannelProcessingFilter.java +++ b/core/src/main/java/org/acegisecurity/securechannel/ChannelProcessingFilter.java @@ -1,4 +1,4 @@ -/* Copyright 2004 Acegi Technology Pty Limited +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.acegisecurity.securechannel; import org.acegisecurity.ConfigAttribute; import org.acegisecurity.ConfigAttributeDefinition; + import org.acegisecurity.intercept.web.FilterInvocation; import org.acegisecurity.intercept.web.FilterInvocationDefinitionSource; @@ -24,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; + import org.springframework.util.Assert; import java.io.IOException; @@ -78,34 +80,19 @@ public class ChannelProcessingFilter implements InitializingBean, Filter { //~ Methods ================================================================ - public void setChannelDecisionManager( - ChannelDecisionManager channelDecisionManager) { - this.channelDecisionManager = channelDecisionManager; - } - - public ChannelDecisionManager getChannelDecisionManager() { - return channelDecisionManager; - } - - public void setFilterInvocationDefinitionSource( - FilterInvocationDefinitionSource filterInvocationDefinitionSource) { - this.filterInvocationDefinitionSource = filterInvocationDefinitionSource; - } - - public FilterInvocationDefinitionSource getFilterInvocationDefinitionSource() { - return filterInvocationDefinitionSource; - } - public void afterPropertiesSet() throws Exception { - Assert.notNull(filterInvocationDefinitionSource, "filterInvocationDefinitionSource must be specified"); - Assert.notNull(channelDecisionManager, "channelDecisionManager must be specified"); + Assert.notNull(filterInvocationDefinitionSource, + "filterInvocationDefinitionSource must be specified"); + Assert.notNull(channelDecisionManager, + "channelDecisionManager must be specified"); Iterator iter = this.filterInvocationDefinitionSource - .getConfigAttributeDefinitions(); + .getConfigAttributeDefinitions(); if (iter == null) { if (logger.isWarnEnabled()) { - logger.warn("Could not validate configuration attributes as the FilterInvocationDefinitionSource did not return a ConfigAttributeDefinition Iterator"); + logger.warn( + "Could not validate configuration attributes as the FilterInvocationDefinitionSource did not return a ConfigAttributeDefinition Iterator"); } return; @@ -115,7 +102,7 @@ public class ChannelProcessingFilter implements InitializingBean, Filter { while (iter.hasNext()) { ConfigAttributeDefinition def = (ConfigAttributeDefinition) iter - .next(); + .next(); Iterator attributes = def.getConfigAttributes(); while (attributes.hasNext()) { @@ -132,7 +119,8 @@ public class ChannelProcessingFilter implements InitializingBean, Filter { logger.info("Validated configuration attributes"); } } else { - throw new IllegalArgumentException("Unsupported configuration attributes: " + set.toString()); + throw new IllegalArgumentException( + "Unsupported configuration attributes: " + set.toString()); } } @@ -154,7 +142,7 @@ public class ChannelProcessingFilter implements InitializingBean, Filter { if (attr != null) { if (logger.isDebugEnabled()) { - logger.debug("Request: " + fi.getFullRequestUrl() + logger.debug("Request: " + fi.toString() + "; ConfigAttributes: " + attr.toString()); } @@ -168,5 +156,23 @@ public class ChannelProcessingFilter implements InitializingBean, Filter { chain.doFilter(request, response); } + public ChannelDecisionManager getChannelDecisionManager() { + return channelDecisionManager; + } + + public FilterInvocationDefinitionSource getFilterInvocationDefinitionSource() { + return filterInvocationDefinitionSource; + } + public void init(FilterConfig filterConfig) throws ServletException {} + + public void setChannelDecisionManager( + ChannelDecisionManager channelDecisionManager) { + this.channelDecisionManager = channelDecisionManager; + } + + public void setFilterInvocationDefinitionSource( + FilterInvocationDefinitionSource filterInvocationDefinitionSource) { + this.filterInvocationDefinitionSource = filterInvocationDefinitionSource; + } } diff --git a/core/src/main/java/org/acegisecurity/ui/AbstractProcessingFilter.java b/core/src/main/java/org/acegisecurity/ui/AbstractProcessingFilter.java index fbd29469ac..85491fd80d 100644 --- a/core/src/main/java/org/acegisecurity/ui/AbstractProcessingFilter.java +++ b/core/src/main/java/org/acegisecurity/ui/AbstractProcessingFilter.java @@ -26,6 +26,7 @@ import org.acegisecurity.event.authentication.InteractiveAuthenticationSuccessEv import org.acegisecurity.ui.rememberme.NullRememberMeServices; import org.acegisecurity.ui.rememberme.RememberMeServices; +import org.acegisecurity.ui.savedrequest.SavedRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -78,10 +79,12 @@ import javax.servlet.http.HttpServletResponse; *
  • * defaultTargetUrl indicates the URL that should be used for * redirection if the HttpSession attribute named {@link - * #ACEGI_SECURITY_TARGET_URL_KEY} does not indicate the target URL once - * authentication is completed successfully. eg: /. This will be - * treated as relative to the web-app's context path, and should include the - * leading /. + * #ACEGI_SAVED_REQUEST_KEY} does not indicate the target URL once + * authentication is completed successfully. eg: /. The + * defaultTargetUrl will be treated as relative to the web-app's + * context path, and should include the leading /. Alternatively, + * inclusion of a scheme name (eg http:// or https://) as the prefix will + * denote a fully-qualified URL and this is also supported. *
  • *
  • * authenticationFailureUrl indicates the URL that should be used @@ -95,8 +98,8 @@ import javax.servlet.http.HttpServletResponse; *
  • * alwaysUseDefaultTargetUrl causes successful authentication to * always redirect to the defaultTargetUrl, even if the - * HttpSession attribute named {@link - * #ACEGI_SECURITY_TARGET_URL_KEY} defines the intended target URL. + * HttpSession attribute named {@link #ACEGI_SAVED_REQUEST_KEY} + * defines the intended target URL. *
  • * * @@ -132,12 +135,15 @@ import javax.servlet.http.HttpServletResponse; * recorded via an AuthenticationManager-specific application * event. *

    + * + * @author Ben Alex + * @version $Id$ */ public abstract class AbstractProcessingFilter implements Filter, InitializingBean, ApplicationEventPublisherAware, MessageSourceAware { //~ Static fields/initializers ============================================= - public static final String ACEGI_SECURITY_TARGET_URL_KEY = "ACEGI_SECURITY_TARGET_URL"; + public static final String ACEGI_SAVED_REQUEST_KEY = "ACEGI_SAVED_REQUEST_KEY"; public static final String ACEGI_SECURITY_LAST_EXCEPTION_KEY = "ACEGI_SECURITY_LAST_EXCEPTION"; //~ Instance fields ======================================================== @@ -303,6 +309,13 @@ public abstract class AbstractProcessingFilter implements Filter, return continueChainBeforeSuccessfulAuthentication; } + public static String obtainFullRequestUrl(HttpServletRequest request) { + SavedRequest savedRequest = (SavedRequest) request.getSession() + .getAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY); + + return (savedRequest == null) ? null : savedRequest.getFullRequestUrl(); + } + protected void onPreAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException {} @@ -428,9 +441,7 @@ public abstract class AbstractProcessingFilter implements Filter, + authResult + "'"); } - String targetUrl = (String) request.getSession() - .getAttribute(ACEGI_SECURITY_TARGET_URL_KEY); - request.getSession().removeAttribute(ACEGI_SECURITY_TARGET_URL_KEY); + String targetUrl = obtainFullRequestUrl(request); if (alwaysUseDefaultTargetUrl == true) { targetUrl = null; diff --git a/core/src/main/java/org/acegisecurity/ui/ExceptionTranslationFilter.java b/core/src/main/java/org/acegisecurity/ui/ExceptionTranslationFilter.java index 8c40a42010..b4b47ac057 100644 --- a/core/src/main/java/org/acegisecurity/ui/ExceptionTranslationFilter.java +++ b/core/src/main/java/org/acegisecurity/ui/ExceptionTranslationFilter.java @@ -24,7 +24,7 @@ import org.acegisecurity.InsufficientAuthenticationException; import org.acegisecurity.context.SecurityContextHolder; -import org.acegisecurity.intercept.web.FilterInvocation; +import org.acegisecurity.ui.savedrequest.SavedRequest; import org.acegisecurity.util.PortResolver; import org.acegisecurity.util.PortResolverImpl; @@ -250,34 +250,20 @@ public class ExceptionTranslationFilter implements Filter, InitializingBean { AuthenticationException reason) throws ServletException, IOException { HttpServletRequest httpRequest = (HttpServletRequest) request; - int port = portResolver.getServerPort(httpRequest); - boolean includePort = true; - - if ("http".equals(httpRequest.getScheme().toLowerCase()) - && (port == 80)) { - includePort = false; - } - - if ("https".equals(httpRequest.getScheme().toLowerCase()) - && (port == 443)) { - includePort = false; - } - - String targetUrl = httpRequest.getScheme() + "://" - + httpRequest.getServerName() + ((includePort) ? (":" + port) : "") - + httpRequest.getContextPath() - + new FilterInvocation(request, response, chain).getRequestUrl(); + SavedRequest savedRequest = new SavedRequest(httpRequest, portResolver); if (logger.isDebugEnabled()) { logger.debug( - "Authentication entry point being called; target URL added to Session: " - + targetUrl); + "Authentication entry point being called; SavedRequest added to Session: " + + savedRequest); } if (createSessionAllowed) { + // Store the HTTP request itself. Used by AbstractProcessingFilter + // for redirection after successful authentication (SEC-29) httpRequest.getSession() - .setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY, - targetUrl); + .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY, + savedRequest); } // SEC-112: Clear the SecurityContextHolder's Authentication, as the diff --git a/core/src/main/java/org/acegisecurity/ui/savedrequest/Enumerator.java b/core/src/main/java/org/acegisecurity/ui/savedrequest/Enumerator.java new file mode 100644 index 0000000000..bf88c49c9d --- /dev/null +++ b/core/src/main/java/org/acegisecurity/ui/savedrequest/Enumerator.java @@ -0,0 +1,152 @@ +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.acegisecurity.ui.savedrequest; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + + +/** + *

    + * Adapter that wraps an Enumeration around a Java 2 collection + * Iterator. + *

    + * + *

    + * Constructors are provided to easily create such wrappers. + *

    + * + *

    + * This class is based on code in Apache Tomcat. + *

    + * + * @author Craig McClanahan + * @author Andrey Grebnev + * @version $Id$ + */ +public class Enumerator implements Enumeration { + //~ Instance fields ======================================================== + + /** + * The Iterator over which the Enumeration + * represented by this class actually operates. + */ + private Iterator iterator = null; + + //~ Constructors =========================================================== + + /** + * Return an Enumeration over the values of the specified Collection. + * + * @param collection Collection whose values should be enumerated + */ + public Enumerator(Collection collection) { + this(collection.iterator()); + } + + /** + * Return an Enumeration over the values of the specified Collection. + * + * @param collection Collection whose values should be enumerated + * @param clone true to clone iterator + */ + public Enumerator(Collection collection, boolean clone) { + this(collection.iterator(), clone); + } + + /** + * Return an Enumeration over the values returned by the specified + * Iterator. + * + * @param iterator Iterator to be wrapped + */ + public Enumerator(Iterator iterator) { + super(); + this.iterator = iterator; + } + + /** + * Return an Enumeration over the values returned by the specified + * Iterator. + * + * @param iterator Iterator to be wrapped + * @param clone true to clone iterator + */ + public Enumerator(Iterator iterator, boolean clone) { + super(); + + if (!clone) { + this.iterator = iterator; + } else { + List list = new ArrayList(); + + while (iterator.hasNext()) { + list.add(iterator.next()); + } + + this.iterator = list.iterator(); + } + } + + /** + * Return an Enumeration over the values of the specified Map. + * + * @param map Map whose values should be enumerated + */ + public Enumerator(Map map) { + this(map.values().iterator()); + } + + /** + * Return an Enumeration over the values of the specified Map. + * + * @param map Map whose values should be enumerated + * @param clone true to clone iterator + */ + public Enumerator(Map map, boolean clone) { + this(map.values().iterator(), clone); + } + + //~ Methods ================================================================ + + /** + * Tests if this enumeration contains more elements. + * + * @return true if and only if this enumeration object + * contains at least one more element to provide, + * false otherwise + */ + public boolean hasMoreElements() { + return (iterator.hasNext()); + } + + /** + * Returns the next element of this enumeration if this enumeration has at + * least one more element to provide. + * + * @return the next element of this enumeration + * + * @exception NoSuchElementException if no more elements exist + */ + public Object nextElement() throws NoSuchElementException { + return (iterator.next()); + } +} diff --git a/core/src/main/java/org/acegisecurity/ui/savedrequest/FastHttpDateFormat.java b/core/src/main/java/org/acegisecurity/ui/savedrequest/FastHttpDateFormat.java new file mode 100644 index 0000000000..4c0cac39f7 --- /dev/null +++ b/core/src/main/java/org/acegisecurity/ui/savedrequest/FastHttpDateFormat.java @@ -0,0 +1,234 @@ +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.acegisecurity.ui.savedrequest; + +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; + +import java.util.Date; +import java.util.HashMap; +import java.util.Locale; +import java.util.TimeZone; + + +/** + *

    + * Utility class to generate HTTP dates. + *

    + * + *

    + * This class is based on code in Apache Tomcat. + *

    + * + * @author Remy Maucherat + * @author Andrey Grebnev + * @version $Id$ + */ +public class FastHttpDateFormat { + //~ Static fields/initializers ============================================= + + /** HTTP date format. */ + protected static final SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", + Locale.US); + + /** + * The set of SimpleDateFormat formats to use in + * getDateHeader(). + */ + protected static final SimpleDateFormat[] formats = {new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", + Locale.US), new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", + Locale.US), new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy", + Locale.US)}; + + /** GMT timezone - all HTTP dates are on GMT */ + protected final static TimeZone gmtZone = TimeZone.getTimeZone("GMT"); + + static { + format.setTimeZone(gmtZone); + + formats[0].setTimeZone(gmtZone); + formats[1].setTimeZone(gmtZone); + formats[2].setTimeZone(gmtZone); + } + + /** Instant on which the currentDate object was generated. */ + protected static long currentDateGenerated = 0L; + + /** Current formatted date. */ + protected static String currentDate = null; + + /** Formatter cache. */ + protected static final HashMap formatCache = new HashMap(); + + /** Parser cache. */ + protected static final HashMap parseCache = new HashMap(); + + //~ Methods ================================================================ + + /** + * Formats a specified date to HTTP format. If local format is not + * null, it's used instead. + * + * @param value Date value to format + * @param threadLocalformat The format to use (or null -- then + * HTTP format will be used) + * + * @return Formatted date + */ + public static final String formatDate(long value, + DateFormat threadLocalformat) { + String cachedDate = null; + Long longValue = new Long(value); + + try { + cachedDate = (String) formatCache.get(longValue); + } catch (Exception e) {} + + if (cachedDate != null) { + return cachedDate; + } + + String newDate = null; + Date dateValue = new Date(value); + + if (threadLocalformat != null) { + newDate = threadLocalformat.format(dateValue); + + synchronized (formatCache) { + updateCache(formatCache, longValue, newDate); + } + } else { + synchronized (formatCache) { + newDate = format.format(dateValue); + updateCache(formatCache, longValue, newDate); + } + } + + return newDate; + } + + /** + * Gets the current date in HTTP format. + * + * @return Current date in HTTP format + */ + public static final String getCurrentDate() { + long now = System.currentTimeMillis(); + + if ((now - currentDateGenerated) > 1000) { + synchronized (format) { + if ((now - currentDateGenerated) > 1000) { + currentDateGenerated = now; + currentDate = format.format(new Date(now)); + } + } + } + + return currentDate; + } + + /** + * Parses date with given formatters. + * + * @param value The string to parse + * @param formats Array of formats to use + * + * @return Parsed date (or null if no formatter mached) + */ + private static final Long internalParseDate(String value, + DateFormat[] formats) { + Date date = null; + + for (int i = 0; (date == null) && (i < formats.length); i++) { + try { + date = formats[i].parse(value); + } catch (ParseException e) { + ; + } + } + + if (date == null) { + return null; + } + + return new Long(date.getTime()); + } + + /** + * Tries to parse the given date as an HTTP date. If local format list is + * not null, it's used instead. + * + * @param value The string to parse + * @param threadLocalformats Array of formats to use for parsing. If + * null, HTTP formats are used. + * + * @return Parsed date (or -1 if error occured) + */ + public static final long parseDate(String value, + DateFormat[] threadLocalformats) { + Long cachedDate = null; + + try { + cachedDate = (Long) parseCache.get(value); + } catch (Exception e) {} + + if (cachedDate != null) { + return cachedDate.longValue(); + } + + Long date = null; + + if (threadLocalformats != null) { + date = internalParseDate(value, threadLocalformats); + + synchronized (parseCache) { + updateCache(parseCache, value, date); + } + } else { + synchronized (parseCache) { + date = internalParseDate(value, formats); + updateCache(parseCache, value, date); + } + } + + if (date == null) { + return (-1L); + } else { + return date.longValue(); + } + } + + /** + * Updates cache. + * + * @param cache Cache to be updated + * @param key Key to be updated + * @param value New value + */ + private static final void updateCache(HashMap cache, Object key, + Object value) { + if (value == null) { + return; + } + + if (cache.size() > 1000) { + cache.clear(); + } + + cache.put(key, value); + } +} diff --git a/core/src/main/java/org/acegisecurity/ui/savedrequest/SavedRequest.java b/core/src/main/java/org/acegisecurity/ui/savedrequest/SavedRequest.java new file mode 100644 index 0000000000..476d0be28b --- /dev/null +++ b/core/src/main/java/org/acegisecurity/ui/savedrequest/SavedRequest.java @@ -0,0 +1,362 @@ +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.acegisecurity.ui.savedrequest; + +import org.acegisecurity.util.PortResolver; +import org.acegisecurity.util.UrlUtils; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; + + +/** + * Represents central information from a HttpServletRequest. + * + *

    + * This class is used by {@link org.acegisecurity.ui.AbstractProcessingFilter} + * and {@link org.acegisecurity.wrapper.SavedRequestAwareWrapper} to reproduce + * the request after successful authentication. An instance of this class is + * stored at the time of an authentication exception by {@link + * org.acegisecurity.ui.ExceptionTranslationFilter}. + *

    + * + *

    + * IMPLEMENTATION NOTE: It is assumed that this object is accessed + * only from the context of a single thread, so no synchronization around + * internal collection classes is performed. + *

    + * + *

    + * This class is based on code in Apache Tomcat. + *

    + * + * @author Craig McClanahan + * @author Andrey Grebnev + * @author Ben Alex + * @version $Id$ + */ +public class SavedRequest { + //~ Static fields/initializers ============================================= + + protected static final Log logger = LogFactory.getLog(SavedRequest.class); + + //~ Instance fields ======================================================== + + private ArrayList cookies = new ArrayList(); + private ArrayList locales = new ArrayList(); + private HashMap headers = new HashMap(); + private HashMap parameters = new HashMap(); + private String contextPath; + private String method; + private String pathInfo; + private String queryString; + private String requestURI; + private String requestURL; + private String scheme; + private String serverName; + private String servletPath; + private int serverPort; + + //~ Constructors =========================================================== + + public SavedRequest(HttpServletRequest request, PortResolver portResolver) { + Assert.notNull(request, "Request required"); + Assert.notNull(portResolver, "PortResolver required"); + + // Cookies + Cookie[] cookies = request.getCookies(); + + if (cookies != null) { + for (int i = 0; i < cookies.length; i++) { + this.addCookie(cookies[i]); + } + } + + // Headers + Enumeration names = request.getHeaderNames(); + + while (names.hasMoreElements()) { + String name = (String) names.nextElement(); + Enumeration values = request.getHeaders(name); + + while (values.hasMoreElements()) { + String value = (String) values.nextElement(); + this.addHeader(name, value); + } + } + + // Locales + Enumeration locales = request.getLocales(); + + while (locales.hasMoreElements()) { + Locale locale = (Locale) locales.nextElement(); + this.addLocale(locale); + } + + // Parameters + Map parameters = request.getParameterMap(); + Iterator paramNames = parameters.keySet().iterator(); + + while (paramNames.hasNext()) { + String paramName = (String) paramNames.next(); + String[] paramValues = (String[]) parameters.get(paramName); + this.addParameter(paramName, paramValues); + } + + // Primitives + this.method = request.getMethod(); + this.pathInfo = request.getPathInfo(); + this.queryString = request.getQueryString(); + this.requestURI = request.getRequestURI(); + this.serverPort = portResolver.getServerPort(request); + this.requestURL = request.getRequestURL().toString(); + this.scheme = request.getScheme(); + this.serverName = request.getServerName(); + this.contextPath = request.getContextPath(); + this.servletPath = request.getServletPath(); + } + + //~ Methods ================================================================ + + private void addCookie(Cookie cookie) { + cookies.add(cookie); + } + + private void addHeader(String name, String value) { + ArrayList values = (ArrayList) headers.get(name); + + if (values == null) { + values = new ArrayList(); + headers.put(name, values); + } + + values.add(value); + } + + private void addLocale(Locale locale) { + locales.add(locale); + } + + private void addParameter(String name, String[] values) { + parameters.put(name, values); + } + + /** + * Determines if the current request matches the SavedRequest. + * All URL arguments are considered, but not method (POST/GET), + * cookies, locales, headers or parameters. + * + * @param request DOCUMENT ME! + * @param portResolver DOCUMENT ME! + * + * @return DOCUMENT ME! + */ + public boolean doesRequestMatch(HttpServletRequest request, + PortResolver portResolver) { + Assert.notNull(request, "Request required"); + Assert.notNull(portResolver, "PortResolver required"); + + if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) { + return false; + } + + if (!propertyEquals("queryString", this.queryString, + request.getQueryString())) { + return false; + } + + if (!propertyEquals("requestURI", this.requestURI, + request.getRequestURI())) { + return false; + } + + if (!propertyEquals("serverPort", new Integer(this.serverPort), + new Integer(portResolver.getServerPort(request)))) { + return false; + } + + if (!propertyEquals("requestURL", this.requestURL, + request.getRequestURL().toString())) { + return false; + } + + if (!propertyEquals("scheme", this.scheme, request.getScheme())) { + return false; + } + + if (!propertyEquals("serverName", this.serverName, + request.getServerName())) { + return false; + } + + if (!propertyEquals("contextPath", this.contextPath, + request.getContextPath())) { + return false; + } + + if (!propertyEquals("servletPath", this.servletPath, + request.getServletPath())) { + return false; + } + + return true; + } + + public String getContextPath() { + return contextPath; + } + + public List getCookies() { + return cookies; + } + + /** + * Indicates the URL that the user agent used for this request. + * + * @return the full URL of this request + */ + public String getFullRequestUrl() { + return UrlUtils.getFullRequestUrl(this); + } + + public Iterator getHeaderNames() { + return (headers.keySet().iterator()); + } + + public Iterator getHeaderValues(String name) { + ArrayList values = (ArrayList) headers.get(name); + + if (values == null) { + return ((new ArrayList()).iterator()); + } else { + return (values.iterator()); + } + } + + public Iterator getLocales() { + return (locales.iterator()); + } + + public String getMethod() { + return (this.method); + } + + public Map getParameterMap() { + return parameters; + } + + public Iterator getParameterNames() { + return (parameters.keySet().iterator()); + } + + public String[] getParameterValues(String name) { + return ((String[]) parameters.get(name)); + } + + public String getPathInfo() { + return pathInfo; + } + + public String getQueryString() { + return (this.queryString); + } + + public String getRequestURI() { + return (this.requestURI); + } + + public String getRequestURL() { + return requestURL; + } + + /** + * Obtains the web application-specific fragment of the URL. + * + * @return the URL, excluding any server name, context path or servlet path + */ + public String getRequestUrl() { + return UrlUtils.getRequestUrl(this); + } + + public String getScheme() { + return scheme; + } + + public String getServerName() { + return serverName; + } + + public int getServerPort() { + return serverPort; + } + + public String getServletPath() { + return servletPath; + } + + private boolean propertyEquals(String log, Object arg1, Object arg2) { + if ((arg1 == null) && (arg2 == null)) { + if (logger.isDebugEnabled()) { + logger.debug(log + ": both null (property equals)"); + } + + return true; + } + + if (((arg1 == null) && (arg2 != null)) + || ((arg1 != null) && (arg2 == null))) { + if (logger.isDebugEnabled()) { + logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + + " (property not equals)"); + } + + return false; + } + + if (arg1.equals(arg2)) { + if (logger.isDebugEnabled()) { + logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + + " (property equals)"); + } + + return true; + } else { + if (logger.isDebugEnabled()) { + logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2 + + " (property not equals)"); + } + + return false; + } + } + + public String toString() { + return "SavedRequest[" + getFullRequestUrl() + "]"; + } +} diff --git a/core/src/main/java/org/acegisecurity/ui/savedrequest/package.html b/core/src/main/java/org/acegisecurity/ui/savedrequest/package.html new file mode 100644 index 0000000000..6264c5ee5a --- /dev/null +++ b/core/src/main/java/org/acegisecurity/ui/savedrequest/package.html @@ -0,0 +1,6 @@ + + +Stores a HttpServletRequest so that it can subsequently be emulated by the +SavedRequestAwareWrapper. + + diff --git a/core/src/main/java/org/acegisecurity/util/UrlUtils.java b/core/src/main/java/org/acegisecurity/util/UrlUtils.java new file mode 100644 index 0000000000..ef9cbe38f0 --- /dev/null +++ b/core/src/main/java/org/acegisecurity/util/UrlUtils.java @@ -0,0 +1,130 @@ +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.acegisecurity.util; + +import org.acegisecurity.intercept.web.FilterInvocation; +import org.acegisecurity.ui.savedrequest.SavedRequest; + +import javax.servlet.http.HttpServletRequest; + + +/** + * Provides static methods for composing URLs. + * + *

    + * Placed into a separate class for visibility, so that changes to URL + * formatting conventions will affect all users. + *

    + * + * @author Ben Alex + * @version $Id$ + */ +public class UrlUtils { + //~ Methods ================================================================ + + /** + * Obtains the full URL the client used to make the request. + * + *

    + * Note that the server port will not be shown if it is the default server + * port for HTTP or HTTPS (ie 80 and 443 respectively). + *

    + * + * @param scheme DOCUMENT ME! + * @param serverName DOCUMENT ME! + * @param serverPort DOCUMENT ME! + * @param contextPath DOCUMENT ME! + * @param requestUrl DOCUMENT ME! + * @param servletPath DOCUMENT ME! + * @param requestURI DOCUMENT ME! + * @param pathInfo DOCUMENT ME! + * @param queryString DOCUMENT ME! + * + * @return the full URL + */ + private static String buildFullRequestUrl(String scheme, String serverName, + int serverPort, String contextPath, String requestUrl, + String servletPath, String requestURI, String pathInfo, + String queryString) { + boolean includePort = true; + + if ("http".equals(scheme.toLowerCase()) && (serverPort == 80)) { + includePort = false; + } + + if ("https".equals(scheme.toLowerCase()) && (serverPort == 443)) { + includePort = false; + } + + return scheme + "://" + serverName + + ((includePort) ? (":" + serverPort) : "") + contextPath + + buildRequestUrl(servletPath, requestURI, contextPath, pathInfo, + queryString); + } + + /** + * Obtains the web application-specific fragment of the URL. + * + * @param servletPath DOCUMENT ME! + * @param requestURI DOCUMENT ME! + * @param contextPath DOCUMENT ME! + * @param pathInfo DOCUMENT ME! + * @param queryString DOCUMENT ME! + * + * @return the URL, excluding any server name, context path or servlet path + */ + private static String buildRequestUrl(String servletPath, + String requestURI, String contextPath, String pathInfo, + String queryString) { + String uri = servletPath; + + if (uri == null) { + uri = requestURI; + uri = uri.substring(contextPath.length()); + } + + return uri + ((pathInfo == null) ? "" : pathInfo) + + ((queryString == null) ? "" : ("?" + queryString)); + } + + public static String getFullRequestUrl(FilterInvocation fi) { + HttpServletRequest r = fi.getHttpRequest(); + + return buildFullRequestUrl(r.getScheme(), r.getServerName(), + r.getServerPort(), r.getContextPath(), + r.getRequestURL().toString(), r.getServletPath(), + r.getRequestURI(), r.getPathInfo(), r.getQueryString()); + } + + public static String getFullRequestUrl(SavedRequest sr) { + return buildFullRequestUrl(sr.getScheme(), sr.getServerName(), + sr.getServerPort(), sr.getContextPath(), sr.getRequestURL(), + sr.getServletPath(), sr.getRequestURI(), sr.getPathInfo(), + sr.getQueryString()); + } + + public static String getRequestUrl(FilterInvocation fi) { + HttpServletRequest r = fi.getHttpRequest(); + + return buildRequestUrl(r.getServletPath(), r.getRequestURI(), + r.getContextPath(), r.getPathInfo(), r.getQueryString()); + } + + public static String getRequestUrl(SavedRequest sr) { + return buildRequestUrl(sr.getServletPath(), sr.getRequestURI(), + sr.getContextPath(), sr.getPathInfo(), sr.getQueryString()); + } +} diff --git a/core/src/main/java/org/acegisecurity/wrapper/SavedRequestAwareWrapper.java b/core/src/main/java/org/acegisecurity/wrapper/SavedRequestAwareWrapper.java new file mode 100644 index 0000000000..4271d950be --- /dev/null +++ b/core/src/main/java/org/acegisecurity/wrapper/SavedRequestAwareWrapper.java @@ -0,0 +1,409 @@ +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.acegisecurity.wrapper; + +import org.acegisecurity.ui.AbstractProcessingFilter; +import org.acegisecurity.ui.savedrequest.Enumerator; +import org.acegisecurity.ui.savedrequest.FastHttpDateFormat; +import org.acegisecurity.ui.savedrequest.SavedRequest; + +import org.acegisecurity.util.PortResolver; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.text.SimpleDateFormat; + +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpSession; + + +/** + * Provides request parameters, headers and cookies from either an original + * request or a saved request. + * + *

    + * Note that not all request parameters in the original request are emulated by + * this wrapper. Nevertheless, the important data from the original request is + * emulated and this should prove adequate for most purposes (in particular + * standard HTTP GET and POST operations). + *

    + * + *

    + * Added into a request by {@link + * org.acegisecurity.wrapper.SecurityContextHolderAwareRequestFilter}. + *

    + * + * @author Andrey Grebnev + * @author Ben Alex + * @version $Id$ + */ +public class SavedRequestAwareWrapper + extends SecurityContextHolderAwareRequestWrapper { + //~ Static fields/initializers ============================================= + + protected static final Log logger = LogFactory.getLog(SavedRequestAwareWrapper.class); + protected static final TimeZone GMT_ZONE = TimeZone.getTimeZone("GMT"); + + /** The default Locale if none are specified. */ + protected static Locale defaultLocale = Locale.getDefault(); + + //~ Instance fields ======================================================== + + protected SavedRequest savedRequest = null; + + /** + * The set of SimpleDateFormat formats to use in getDateHeader(). Notice + * that because SimpleDateFormat is not thread-safe, we can't declare + * formats[] as a static variable. + */ + protected SimpleDateFormat[] formats = new SimpleDateFormat[3]; + + //~ Constructors =========================================================== + + public SavedRequestAwareWrapper(HttpServletRequest request, + PortResolver portResolver) { + super(request); + + HttpSession session = request.getSession(false); + + if (session == null) { + if (logger.isDebugEnabled()) { + logger.debug( + "Wrapper not replaced; no session available for SavedRequest extraction"); + } + + return; + } + + SavedRequest saved = (SavedRequest) session.getAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY); + + if ((saved != null) && saved.doesRequestMatch(request, portResolver)) { + if (logger.isDebugEnabled()) { + logger.debug("Wrapper replaced; SavedRequest was: " + saved); + } + + savedRequest = saved; + session.removeAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY); + + formats[0] = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", + Locale.US); + formats[1] = new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", + Locale.US); + formats[2] = new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy", + Locale.US); + + formats[0].setTimeZone(GMT_ZONE); + formats[1].setTimeZone(GMT_ZONE); + formats[2].setTimeZone(GMT_ZONE); + } else { + if (logger.isDebugEnabled()) { + logger.debug("Wrapper not replaced; SavedRequest was: " + saved); + } + } + } + + //~ Methods ================================================================ + + /** + * The default behavior of this method is to return getCookies() on the + * wrapped request object. + * + * @return DOCUMENT ME! + */ + public Cookie[] getCookies() { + if (savedRequest == null) { + return super.getCookies(); + } else { + List cookies = savedRequest.getCookies(); + + return (Cookie[]) cookies.toArray(new Cookie[cookies.size()]); + } + } + + /** + * The default behavior of this method is to return getDateHeader(String + * name) on the wrapped request object. + * + * @param name DOCUMENT ME! + * + * @return DOCUMENT ME! + * + * @throws IllegalArgumentException DOCUMENT ME! + */ + public long getDateHeader(String name) { + if (savedRequest == null) { + return super.getDateHeader(name); + } else { + String value = getHeader(name); + + if (value == null) { + return (-1L); + } + + // Attempt to convert the date header in a variety of formats + long result = FastHttpDateFormat.parseDate(value, formats); + + if (result != (-1L)) { + return result; + } + + throw new IllegalArgumentException(value); + } + } + + /** + * The default behavior of this method is to return getHeader(String name) + * on the wrapped request object. + * + * @param name DOCUMENT ME! + * + * @return DOCUMENT ME! + */ + public String getHeader(String name) { + if (savedRequest == null) { + return super.getHeader(name); + } else { + String header = null; + Iterator iterator = savedRequest.getHeaderValues(name); + + while (iterator.hasNext()) { + header = (String) iterator.next(); + + break; + } + + return header; + } + } + + /** + * The default behavior of this method is to return getHeaderNames() on the + * wrapped request object. + * + * @return DOCUMENT ME! + */ + public Enumeration getHeaderNames() { + if (savedRequest == null) { + return super.getHeaderNames(); + } else { + return new Enumerator(savedRequest.getHeaderNames()); + } + } + + /** + * The default behavior of this method is to return getHeaders(String name) + * on the wrapped request object. + * + * @param name DOCUMENT ME! + * + * @return DOCUMENT ME! + */ + public Enumeration getHeaders(String name) { + if (savedRequest == null) { + return super.getHeaders(name); + } else { + return new Enumerator(savedRequest.getHeaderValues(name)); + } + } + + /** + * The default behavior of this method is to return getIntHeader(String + * name) on the wrapped request object. + * + * @param name DOCUMENT ME! + * + * @return DOCUMENT ME! + */ + public int getIntHeader(String name) { + if (savedRequest == null) { + return super.getIntHeader(name); + } else { + String value = getHeader(name); + + if (value == null) { + return (-1); + } else { + return (Integer.parseInt(value)); + } + } + } + + /** + * The default behavior of this method is to return getLocale() on the + * wrapped request object. + * + * @return DOCUMENT ME! + */ + public Locale getLocale() { + if (savedRequest == null) { + return super.getLocale(); + } else { + Locale locale = null; + Iterator iterator = savedRequest.getLocales(); + + while (iterator.hasNext()) { + locale = (Locale) iterator.next(); + + break; + } + + if (locale == null) { + return defaultLocale; + } else { + return locale; + } + } + } + + /** + * The default behavior of this method is to return getLocales() on the + * wrapped request object. + * + * @return DOCUMENT ME! + */ + public Enumeration getLocales() { + if (savedRequest == null) { + return super.getLocales(); + } else { + Iterator iterator = savedRequest.getLocales(); + + if (iterator.hasNext()) { + return new Enumerator(iterator); + } else { + ArrayList results = new ArrayList(); + results.add(defaultLocale); + + return new Enumerator(results.iterator()); + } + } + } + + /** + * The default behavior of this method is to return getMethod() on the + * wrapped request object. + * + * @return DOCUMENT ME! + */ + public String getMethod() { + if (savedRequest == null) { + return super.getMethod(); + } else { + return savedRequest.getMethod(); + } + } + + /** + * The default behavior of this method is to return getParameter(String + * name) on the wrapped request object. + * + * @param name DOCUMENT ME! + * + * @return DOCUMENT ME! + */ + public String getParameter(String name) { +/* + if (savedRequest == null) { + return super.getParameter(name); + } else { + String value = null; + String[] values = savedRequest.getParameterValues(name); + if (values == null) + return null; + for (int i = 0; i < values.length; i++) { + value = values[i]; + break; + } + return value; + } + */ + + //we do not get value from super.getParameter because there is a bug in Jetty servlet-container + String value = null; + String[] values = null; + + if (savedRequest == null) { + values = super.getParameterValues(name); + } else { + values = savedRequest.getParameterValues(name); + } + + if (values == null) { + return null; + } + + for (int i = 0; i < values.length; i++) { + value = values[i]; + + break; + } + + return value; + } + + /** + * The default behavior of this method is to return getParameterMap() on + * the wrapped request object. + * + * @return DOCUMENT ME! + */ + public Map getParameterMap() { + if (savedRequest == null) { + return super.getParameterMap(); + } else { + return savedRequest.getParameterMap(); + } + } + + /** + * The default behavior of this method is to return getParameterNames() on + * the wrapped request object. + * + * @return DOCUMENT ME! + */ + public Enumeration getParameterNames() { + if (savedRequest == null) { + return super.getParameterNames(); + } else { + return new Enumerator(savedRequest.getParameterNames()); + } + } + + /** + * The default behavior of this method is to return + * getParameterValues(String name) on the wrapped request object. + * + * @param name DOCUMENT ME! + * + * @return DOCUMENT ME! + */ + public String[] getParameterValues(String name) { + if (savedRequest == null) { + return super.getParameterValues(name); + } else { + return savedRequest.getParameterValues(name); + } + } +} diff --git a/core/src/main/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilter.java b/core/src/main/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilter.java index ae19ea52af..4415133d58 100644 --- a/core/src/main/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilter.java +++ b/core/src/main/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilter.java @@ -1,4 +1,4 @@ -/* Copyright 2004, 2005 Acegi Technology Pty Limited +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,16 @@ package org.acegisecurity.wrapper; +import org.acegisecurity.util.PortResolver; +import org.acegisecurity.util.PortResolverImpl; + +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + import java.io.IOException; +import java.lang.reflect.Constructor; + import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -27,13 +35,38 @@ import javax.servlet.http.HttpServletRequest; /** - * A Filter which populates the ServletRequest with - * an {@link SecurityContextHolderAwareRequestWrapper}. + * A Filter which populates the ServletRequest with a + * new request wrapper. + * + *

    + * Several request wrappers are included with the framework. The simplest + * version is {@link SecurityContextHolderAwareRequestWrapper}. A more complex + * and powerful request wrapper is {@link + * org.acegisecurity.wrapper.SavedRequestAwareWrapper}. The latter is also the + * default. + *

    + * + *

    + * To modify the wrapper used, call {@link #setWrapperClass(Class)}. + *

    + * + *

    + * Any request wrapper configured for instantiation by this class must provide + * a public constructor that accepts two arguments, being a + * HttpServletRequest and a PortResolver. + *

    * * @author Orlando Garcia Carmona + * @author Ben Alex * @version $Id$ */ public class SecurityContextHolderAwareRequestFilter implements Filter { + //~ Instance fields ======================================================== + + private Class wrapperClass = SavedRequestAwareWrapper.class; + private Constructor constructor; + private PortResolver portResolver = new PortResolverImpl(); + //~ Methods ================================================================ public void destroy() {} @@ -43,12 +76,36 @@ public class SecurityContextHolderAwareRequestFilter implements Filter { throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) servletRequest; - if (!(request instanceof SecurityContextHolderAwareRequestWrapper)) { - request = new SecurityContextHolderAwareRequestWrapper(request); + if (!wrapperClass.isAssignableFrom(request.getClass())) { + if (constructor == null) { + try { + constructor = wrapperClass.getConstructor(new Class[] {HttpServletRequest.class, PortResolver.class}); + } catch (Exception ex) { + ReflectionUtils.handleReflectionException(ex); + } + } + + try { + request = (HttpServletRequest) constructor.newInstance(new Object[] {request, portResolver}); + } catch (Exception ex) { + ReflectionUtils.handleReflectionException(ex); + } } filterChain.doFilter(request, servletResponse); } public void init(FilterConfig filterConfig) throws ServletException {} + + public void setPortResolver(PortResolver portResolver) { + Assert.notNull(portResolver, "PortResolver required"); + this.portResolver = portResolver; + } + + public void setWrapperClass(Class wrapperClass) { + Assert.notNull(wrapperClass, "WrapperClass required"); + Assert.isTrue(HttpServletRequest.class.isAssignableFrom(wrapperClass), + "Wrapper must be a HttpServletRequest"); + this.wrapperClass = wrapperClass; + } } diff --git a/core/src/test/java/org/acegisecurity/intercept/web/FilterInvocationTests.java b/core/src/test/java/org/acegisecurity/intercept/web/FilterInvocationTests.java index 3ca684888c..9ef695a263 100644 --- a/core/src/test/java/org/acegisecurity/intercept/web/FilterInvocationTests.java +++ b/core/src/test/java/org/acegisecurity/intercept/web/FilterInvocationTests.java @@ -1,4 +1,4 @@ -/* Copyright 2004, 2005 Acegi Technology Pty Limited +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,14 @@ package org.acegisecurity.intercept.web; import org.acegisecurity.MockFilterChain; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; +import org.jmock.MockObjectTestCase; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.jmock.MockObjectTestCase; + +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + /** * Tests {@link FilterInvocation}. @@ -44,14 +46,14 @@ public class FilterInvocationTests extends MockObjectTestCase { //~ Methods ================================================================ - public final void setUp() throws Exception { - super.setUp(); - } - public static void main(String[] args) { junit.textui.TestRunner.run(FilterInvocationTests.class); } + public final void setUp() throws Exception { + super.setUp(); + } + public void testGettersAndStringMethods() { MockHttpServletRequest request = new MockHttpServletRequest(null, null); request.setServletPath("/HelloWorld"); @@ -73,7 +75,7 @@ public class FilterInvocationTests extends MockObjectTestCase { assertEquals("/HelloWorld/some/more/segments.html", fi.getRequestUrl()); assertEquals("FilterInvocation: URL: /HelloWorld/some/more/segments.html", fi.toString()); - assertEquals("http://www.example.com:80/mycontext/HelloWorld/some/more/segments.html", + assertEquals("http://www.example.com/mycontext/HelloWorld/some/more/segments.html", fi.getFullRequestUrl()); } @@ -81,7 +83,7 @@ public class FilterInvocationTests extends MockObjectTestCase { Class clazz = FilterInvocation.class; try { - clazz.getDeclaredConstructor((Class[])null); + clazz.getDeclaredConstructor((Class[]) null); fail("Should have thrown NoSuchMethodException"); } catch (NoSuchMethodException expected) { assertTrue(true); @@ -125,7 +127,7 @@ public class FilterInvocationTests extends MockObjectTestCase { } public void testRejectsServletRequestWhichIsNotHttpServletRequest() { - ServletRequest request = (ServletRequest)newDummy(ServletRequest.class); + ServletRequest request = (ServletRequest) newDummy(ServletRequest.class); MockHttpServletResponse response = new MockHttpServletResponse(); MockFilterChain chain = new MockFilterChain(); @@ -167,7 +169,7 @@ public class FilterInvocationTests extends MockObjectTestCase { FilterInvocation fi = new FilterInvocation(request, response, chain); assertEquals("/HelloWorld?foo=bar", fi.getRequestUrl()); assertEquals("FilterInvocation: URL: /HelloWorld?foo=bar", fi.toString()); - assertEquals("http://www.example.com:80/mycontext/HelloWorld?foo=bar", + assertEquals("http://www.example.com/mycontext/HelloWorld?foo=bar", fi.getFullRequestUrl()); } @@ -185,7 +187,7 @@ public class FilterInvocationTests extends MockObjectTestCase { FilterInvocation fi = new FilterInvocation(request, response, chain); assertEquals("/HelloWorld", fi.getRequestUrl()); assertEquals("FilterInvocation: URL: /HelloWorld", fi.toString()); - assertEquals("http://www.example.com:80/mycontext/HelloWorld", + assertEquals("http://www.example.com/mycontext/HelloWorld", fi.getFullRequestUrl()); } } diff --git a/core/src/test/java/org/acegisecurity/ui/AbstractProcessingFilterTests.java b/core/src/test/java/org/acegisecurity/ui/AbstractProcessingFilterTests.java index 67f86fbe77..b36fca0923 100644 --- a/core/src/test/java/org/acegisecurity/ui/AbstractProcessingFilterTests.java +++ b/core/src/test/java/org/acegisecurity/ui/AbstractProcessingFilterTests.java @@ -30,6 +30,9 @@ import org.acegisecurity.context.SecurityContextHolder; import org.acegisecurity.providers.UsernamePasswordAuthenticationToken; import org.acegisecurity.ui.rememberme.TokenBasedRememberMeServices; +import org.acegisecurity.ui.savedrequest.SavedRequest; + +import org.acegisecurity.util.PortResolverImpl; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; @@ -91,6 +94,16 @@ public class AbstractProcessingFilterTests extends TestCase { junit.textui.TestRunner.run(AbstractProcessingFilterTests.class); } + private SavedRequest makeSavedRequestForUrl() { + MockHttpServletRequest request = createMockRequest(); + request.setServletPath("/some_protected_file.html"); + request.setScheme("http"); + request.setServerName("www.example.com"); + request.setRequestURI("/mycontext/some_protected_file.html"); + + return new SavedRequest(request, new PortResolverImpl()); + } + protected void setUp() throws Exception { super.setUp(); SecurityContextHolder.clearContext(); @@ -399,8 +412,8 @@ public class AbstractProcessingFilterTests extends TestCase { // Setup our HTTP request MockHttpServletRequest request = createMockRequest(); request.getSession() - .setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY, - "/my-destination"); + .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY, + makeSavedRequestForUrl()); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null); @@ -429,8 +442,8 @@ public class AbstractProcessingFilterTests extends TestCase { // Setup our HTTP request MockHttpServletRequest request = createMockRequest(); request.getSession() - .setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY, - "/my-destination"); + .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY, + makeSavedRequestForUrl()); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null); @@ -446,7 +459,8 @@ public class AbstractProcessingFilterTests extends TestCase { // Test executeFilterInContainerSimulator(config, filter, request, response, chain); - assertEquals("/my-destination", response.getRedirectedUrl()); + assertEquals(makeSavedRequestForUrl().getFullRequestUrl(), + response.getRedirectedUrl()); assertNotNull(SecurityContextHolder.getContext().getAuthentication()); } diff --git a/core/src/test/java/org/acegisecurity/ui/ExceptionTranslationFilterTests.java b/core/src/test/java/org/acegisecurity/ui/ExceptionTranslationFilterTests.java index 7e5372dc15..ca5a6adada 100644 --- a/core/src/test/java/org/acegisecurity/ui/ExceptionTranslationFilterTests.java +++ b/core/src/test/java/org/acegisecurity/ui/ExceptionTranslationFilterTests.java @@ -28,8 +28,6 @@ import org.acegisecurity.context.SecurityContextHolder; import org.acegisecurity.providers.anonymous.AnonymousAuthenticationToken; -import org.acegisecurity.ui.webapp.AuthenticationProcessingFilter; - import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -101,8 +99,7 @@ public class ExceptionTranslationFilterTests extends TestCase { filter.doFilter(request, response, chain); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("http://www.example.com/mycontext/secure/page.html", - request.getSession() - .getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY)); + AbstractProcessingFilter.obtainFullRequestUrl(request)); } public void testAccessDeniedWhenNonAnonymous() throws Exception { @@ -192,8 +189,7 @@ public class ExceptionTranslationFilterTests extends TestCase { filter.doFilter(request, response, chain); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("http://www.example.com/mycontext/secure/page.html", - request.getSession() - .getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY)); + AbstractProcessingFilter.obtainFullRequestUrl(request)); } public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException() @@ -221,8 +217,7 @@ public class ExceptionTranslationFilterTests extends TestCase { filter.doFilter(request, response, chain); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("http://www.example.com:8080/mycontext/secure/page.html", - request.getSession() - .getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY)); + AbstractProcessingFilter.obtainFullRequestUrl(request)); } public void testStartupDetectsMissingAuthenticationEntryPoint() diff --git a/core/src/test/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilterTests.java b/core/src/test/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilterTests.java index 6f7928752e..40c6f7a68a 100644 --- a/core/src/test/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilterTests.java +++ b/core/src/test/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilterTests.java @@ -1,4 +1,4 @@ -/* Copyright 2004, 2005 Acegi Technology Pty Limited +/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ import junit.framework.TestCase; import org.acegisecurity.MockFilterConfig; +import org.springframework.mock.web.MockHttpServletRequest; + import java.io.IOException; import javax.servlet.FilterChain; @@ -26,8 +28,6 @@ import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; -import org.springframework.mock.web.MockHttpServletRequest; - /** * Tests {@link SecurityContextHolderAwareRequestFilter}. @@ -48,23 +48,23 @@ public class SecurityContextHolderAwareRequestFilterTests extends TestCase { //~ Methods ================================================================ - public final void setUp() throws Exception { - super.setUp(); - } - public static void main(String[] args) { junit.textui.TestRunner.run(SecurityContextHolderAwareRequestFilterTests.class); } + public final void setUp() throws Exception { + super.setUp(); + } + public void testCorrectOperation() throws Exception { SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter(); filter.init(new MockFilterConfig()); filter.doFilter(new MockHttpServletRequest(null, null), null, - new MockFilterChain(SecurityContextHolderAwareRequestWrapper.class)); + new MockFilterChain(SavedRequestAwareWrapper.class)); // Now re-execute the filter, ensuring our replacement wrapper is still used filter.doFilter(new MockHttpServletRequest(null, null), null, - new MockFilterChain(SecurityContextHolderAwareRequestWrapper.class)); + new MockFilterChain(SavedRequestAwareWrapper.class)); filter.destroy(); } diff --git a/samples/contacts/src/main/webapp/filter/WEB-INF/applicationContext-acegi-security.xml b/samples/contacts/src/main/webapp/filter/WEB-INF/applicationContext-acegi-security.xml index 373316c057..38f61938e1 100644 --- a/samples/contacts/src/main/webapp/filter/WEB-INF/applicationContext-acegi-security.xml +++ b/samples/contacts/src/main/webapp/filter/WEB-INF/applicationContext-acegi-security.xml @@ -21,7 +21,7 @@ CONVERT_URL_TO_LOWERCASE_BEFORE_COMPARISON PATTERN_TYPE_APACHE_ANT - /**=httpSessionContextIntegrationFilter,logoutFilter,authenticationProcessingFilter,basicProcessingFilter,rememberMeProcessingFilter,anonymousProcessingFilter,switchUserProcessingFilter,exceptionTranslationFilter,filterInvocationInterceptor + /**=httpSessionContextIntegrationFilter,logoutFilter,authenticationProcessingFilter,basicProcessingFilter,securityContextHolderAwareRequestFilter,rememberMeProcessingFilter,anonymousProcessingFilter,switchUserProcessingFilter,exceptionTranslationFilter,filterInvocationInterceptor @@ -112,6 +112,8 @@ + +