From 6b24637fbc0eead728a99b34b607149cdbab4275 Mon Sep 17 00:00:00 2001 From: Luke Taylor Date: Fri, 21 Nov 2008 12:17:43 +0000 Subject: [PATCH] Further SavedRequestWrapper related tests and tidying up. --- .../security/ui/savedrequest/Enumerator.java | 30 +-- .../ui/savedrequest/SavedRequest.java | 72 +++-- .../wrapper/SavedRequestAwareWrapper.java | 176 ++++++------ ...urityContextHolderAwareRequestWrapper.java | 3 + .../SavedRequestAwareWrapperTests.java | 254 +++++++++++++----- 5 files changed, 324 insertions(+), 211 deletions(-) diff --git a/core/src/main/java/org/springframework/security/ui/savedrequest/Enumerator.java b/core/src/main/java/org/springframework/security/ui/savedrequest/Enumerator.java index e3ea12f3d5..71ee6030be 100644 --- a/core/src/main/java/org/springframework/security/ui/savedrequest/Enumerator.java +++ b/core/src/main/java/org/springframework/security/ui/savedrequest/Enumerator.java @@ -33,44 +33,42 @@ import java.util.NoSuchElementException; * @author Andrey Grebnev * @version $Id$ */ -@SuppressWarnings("unchecked") -public class Enumerator implements Enumeration { +public class Enumerator implements Enumeration { //~ Instance fields ================================================================================================ /** * The Iterator over which the Enumeration represented by this class actually operates. */ - private Iterator iterator = null; + private Iterator iterator = null; //~ Constructors =================================================================================================== -/** + /** * Return an Enumeration over the values of the specified Collection. * * @param collection Collection whose values should be enumerated */ - public Enumerator(Collection collection) { + public Enumerator(Collection collection) { this(collection.iterator()); } -/** + /** * Return an Enumeration over the values of the specified Collection. * * @param collection Collection whose values should be enumerated * @param clone true to clone iterator */ - public Enumerator(Collection collection, boolean clone) { + public Enumerator(Collection collection, boolean clone) { this(collection.iterator(), clone); } -/** + /** * Return an Enumeration over the values returned by the specified * Iterator. * * @param iterator Iterator to be wrapped */ - public Enumerator(Iterator iterator) { - super(); + public Enumerator(Iterator iterator) { this.iterator = iterator; } @@ -81,12 +79,12 @@ public class Enumerator implements Enumeration { * @param iterator Iterator to be wrapped * @param clone true to clone iterator */ - public Enumerator(Iterator iterator, boolean clone) { + public Enumerator(Iterator iterator, boolean clone) { if (!clone) { this.iterator = iterator; } else { - List list = new ArrayList(); + List list = new ArrayList(); while (iterator.hasNext()) { list.add(iterator.next()); @@ -101,17 +99,17 @@ public class Enumerator implements Enumeration { * * @param map Map whose values should be enumerated */ - public Enumerator(Map map) { + public Enumerator(Map map) { this(map.values().iterator()); } -/** + /** * Return an Enumeration over the values of the specified Map. * * @param map Map whose values should be enumerated * @param clone true to clone iterator */ - public Enumerator(Map map, boolean clone) { + public Enumerator(Map map, boolean clone) { this(map.values().iterator(), clone); } @@ -135,7 +133,7 @@ public class Enumerator implements Enumeration { * * @exception NoSuchElementException if no more elements exist */ - public Object nextElement() throws NoSuchElementException { + public T nextElement() throws NoSuchElementException { return (iterator.next()); } } diff --git a/core/src/main/java/org/springframework/security/ui/savedrequest/SavedRequest.java b/core/src/main/java/org/springframework/security/ui/savedrequest/SavedRequest.java index c359b559a4..3c2f5a8491 100644 --- a/core/src/main/java/org/springframework/security/ui/savedrequest/SavedRequest.java +++ b/core/src/main/java/org/springframework/security/ui/savedrequest/SavedRequest.java @@ -24,6 +24,7 @@ import org.springframework.util.Assert; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import java.util.ArrayList; +import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; import java.util.List; @@ -46,7 +47,6 @@ import java.util.TreeMap; * @author Ben Alex * @version $Id$ */ -@SuppressWarnings("unchecked") public class SavedRequest implements java.io.Serializable { //~ Static fields/initializers ===================================================================================== @@ -54,10 +54,10 @@ public class SavedRequest implements java.io.Serializable { //~ Instance fields ================================================================================================ - private ArrayList cookies = new ArrayList(); - private ArrayList locales = new ArrayList(); - private Map headers = new TreeMap(String.CASE_INSENSITIVE_ORDER); - private Map parameters = new TreeMap(String.CASE_INSENSITIVE_ORDER); + private ArrayList cookies = new ArrayList(); + private ArrayList locales = new ArrayList(); + private Map> headers = new TreeMap>(String.CASE_INSENSITIVE_ORDER); + private Map parameters = new TreeMap(String.CASE_INSENSITIVE_ORDER); private String contextPath; private String method; private String pathInfo; @@ -71,6 +71,7 @@ public class SavedRequest implements java.io.Serializable { //~ Constructors =================================================================================================== + @SuppressWarnings("unchecked") public SavedRequest(HttpServletRequest request, PortResolver portResolver) { Assert.notNull(request, "Request required"); Assert.notNull(portResolver, "PortResolver required"); @@ -85,20 +86,19 @@ public class SavedRequest implements java.io.Serializable { } // Headers - Enumeration names = request.getHeaderNames(); + Enumeration names = request.getHeaderNames(); while (names.hasMoreElements()) { - String name = (String) names.nextElement(); - Enumeration values = request.getHeaders(name); + String name = names.nextElement(); + Enumeration values = request.getHeaders(name); while (values.hasMoreElements()) { - String value = (String) values.nextElement(); - this.addHeader(name, value); + this.addHeader(name, values.nextElement()); } } // Locales - Enumeration locales = request.getLocales(); + Enumeration locales = request.getLocales(); while (locales.hasMoreElements()) { Locale locale = (Locale) locales.nextElement(); @@ -106,15 +106,12 @@ public class SavedRequest implements java.io.Serializable { } // Parameters - Map parameters = request.getParameterMap(); - Iterator paramNames = parameters.keySet().iterator(); + Map parameters = request.getParameterMap(); - while (paramNames.hasNext()) { - String paramName = (String) paramNames.next(); - Object o = parameters.get(paramName); - if (o instanceof String[]) { - String[] paramValues = (String[]) o; - this.addParameter(paramName, paramValues); + for(String paramName : parameters.keySet()) { + Object paramValues = parameters.get(paramName); + if (paramValues instanceof String[]) { + this.addParameter(paramName, (String[]) paramValues); } else { if (logger.isWarnEnabled()) { logger.warn("ServletRequest.getParameterMap() returned non-String array"); @@ -142,10 +139,10 @@ public class SavedRequest implements java.io.Serializable { } private void addHeader(String name, String value) { - ArrayList values = (ArrayList) headers.get(name); + List values = headers.get(name); if (values == null) { - values = new ArrayList(); + values = new ArrayList(); headers.put(name, values); } @@ -163,10 +160,6 @@ public class SavedRequest implements java.io.Serializable { /** * Determines if the current request matches the SavedRequest. All URL arguments are * considered, but not method (POST/GET), cookies, locales, headers or parameters. - * - * @param request DOCUMENT ME! - * @param portResolver DOCUMENT ME! - * @return DOCUMENT ME! */ public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) { Assert.notNull(request, "Request required"); @@ -216,12 +209,13 @@ public class SavedRequest implements java.io.Serializable { return contextPath; } - public List getCookies() { - List cookieList = new ArrayList(cookies.size()); - for (Iterator iterator = cookies.iterator(); iterator.hasNext();) { - SavedCookie savedCookie = (SavedCookie) iterator.next(); + public List getCookies() { + List cookieList = new ArrayList(cookies.size()); + + for (SavedCookie savedCookie : cookies) { cookieList.add(savedCookie.getCookie()); } + return cookieList; } @@ -234,33 +228,33 @@ public class SavedRequest implements java.io.Serializable { return UrlUtils.getFullRequestUrl(this); } - public Iterator getHeaderNames() { + public Iterator getHeaderNames() { return (headers.keySet().iterator()); } - public Iterator getHeaderValues(String name) { - ArrayList values = (ArrayList) headers.get(name); + public Iterator getHeaderValues(String name) { + List values = headers.get(name); if (values == null) { - return ((new ArrayList()).iterator()); - } else { - return (values.iterator()); + values = Collections.emptyList(); } + + return (values.iterator()); } - public Iterator getLocales() { + public Iterator getLocales() { return (locales.iterator()); } public String getMethod() { - return (this.method); + return method; } - public Map getParameterMap() { + public Map getParameterMap() { return parameters; } - public Iterator getParameterNames() { + public Iterator getParameterNames() { return (parameters.keySet().iterator()); } diff --git a/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java b/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java index a88ade4ffa..fce84e952e 100644 --- a/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java +++ b/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java @@ -15,18 +15,7 @@ package org.springframework.security.wrapper; -import org.springframework.security.ui.AbstractProcessingFilter; -import org.springframework.security.ui.savedrequest.Enumerator; -import org.springframework.security.ui.savedrequest.FastHttpDateFormat; -import org.springframework.security.ui.savedrequest.SavedRequest; - -import org.springframework.security.util.PortResolver; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - import java.text.SimpleDateFormat; - import java.util.ArrayList; import java.util.Arrays; import java.util.Enumeration; @@ -38,12 +27,19 @@ import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.TimeZone; -import java.util.Map.Entry; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.security.ui.AbstractProcessingFilter; +import org.springframework.security.ui.savedrequest.Enumerator; +import org.springframework.security.ui.savedrequest.FastHttpDateFormat; +import org.springframework.security.ui.savedrequest.SavedRequest; +import org.springframework.security.util.PortResolver; + /** * Provides request parameters, headers and cookies from either an original request or a saved request. @@ -121,16 +117,18 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW //~ Methods ======================================================================================================== + @Override public Cookie[] getCookies() { if (savedRequest == null) { return super.getCookies(); } else { - List cookies = savedRequest.getCookies(); + List cookies = savedRequest.getCookies(); - return (Cookie[]) cookies.toArray(new Cookie[cookies.size()]); + return cookies.toArray(new Cookie[cookies.size()]); } } + @Override public long getDateHeader(String name) { if (savedRequest == null) { return super.getDateHeader(name); @@ -152,15 +150,16 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW } } + @Override public String getHeader(String name) { if (savedRequest == null) { return super.getHeader(name); } else { String header = null; - Iterator iterator = savedRequest.getHeaderValues(name); + Iterator iterator = savedRequest.getHeaderValues(name); while (iterator.hasNext()) { - header = (String) iterator.next(); + header = iterator.next(); break; } @@ -169,22 +168,25 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW } } + @Override public Enumeration getHeaderNames() { if (savedRequest == null) { return super.getHeaderNames(); } else { - return new Enumerator(savedRequest.getHeaderNames()); + return new Enumerator(savedRequest.getHeaderNames()); } } + @Override public Enumeration getHeaders(String name) { if (savedRequest == null) { return super.getHeaders(name); } else { - return new Enumerator(savedRequest.getHeaderValues(name)); + return new Enumerator(savedRequest.getHeaderValues(name)); } } + @Override public int getIntHeader(String name) { if (savedRequest == null) { return super.getIntHeader(name); @@ -199,12 +201,13 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW } } + @Override public Locale getLocale() { if (savedRequest == null) { return super.getLocale(); } else { Locale locale = null; - Iterator iterator = savedRequest.getLocales(); + Iterator iterator = savedRequest.getLocales(); while (iterator.hasNext()) { locale = (Locale) iterator.next(); @@ -220,23 +223,25 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW } } + @Override public Enumeration getLocales() { if (savedRequest == null) { return super.getLocales(); - } else { - Iterator iterator = savedRequest.getLocales(); - - if (iterator.hasNext()) { - return new Enumerator(iterator); - } else { - ArrayList results = new ArrayList(); - results.add(defaultLocale); - - return new Enumerator(results.iterator()); - } } + + Iterator iterator = savedRequest.getLocales(); + + if (iterator.hasNext()) { + return new Enumerator(iterator); + } + // Fall back to default locale + ArrayList results = new ArrayList(1); + results.add(defaultLocale); + + return new Enumerator(results.iterator()); } + @Override public String getMethod() { if (savedRequest == null) { return super.getMethod(); @@ -257,83 +262,82 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW * If the value from the wrapped request is null, an attempt will be made to retrieve the parameter * from the SavedRequest, if available.. */ + @Override public String getParameter(String name) { String value = super.getParameter(name); - + if (value != null || savedRequest == null) { - return value; + return value; } String[] values = savedRequest.getParameterValues(name); - if (values == null) - return null; - for (int i = 0; i < values.length; i++) { - value = values[i]; - break; - } + if (values == null || values.length == 0) { + return null; + } - return value; + return values[0]; } + @Override public Map getParameterMap() { - if (savedRequest == null) { + if (savedRequest == null) { return super.getParameterMap(); } - - Set names = getCombinedParameterNames(); - Iterator nameIter = names.iterator(); - Map parameterMap = new HashMap(names.size()); - - while (nameIter.hasNext()) { - String name = (String) nameIter.next(); - parameterMap.put(name, getParameterValues(name)); - } - - return parameterMap; - } - - private Set getCombinedParameterNames() { - Set names = new HashSet(); - names.addAll(super.getParameterMap().keySet()); - - if (savedRequest != null) { - names.addAll(savedRequest.getParameterMap().keySet()); - } - - return names; + + Set names = getCombinedParameterNames(); + Map parameterMap = new HashMap(names.size()); + + for (String name : names) { + parameterMap.put(name, getParameterValues(name)); + } + + return parameterMap; } + private Set getCombinedParameterNames() { + Set names = new HashSet(); + names.addAll(super.getParameterMap().keySet()); + + if (savedRequest != null) { + names.addAll(savedRequest.getParameterMap().keySet()); + } + + return names; + } + + @Override public Enumeration getParameterNames() { - return new Enumerator(getCombinedParameterNames()); + return new Enumerator(getCombinedParameterNames()); } + @Override public String[] getParameterValues(String name) { - if (savedRequest == null) { - return super.getParameterValues(name); - } - - String[] savedRequestParams = savedRequest.getParameterValues(name); - String[] wrappedRequestParams = super.getParameterValues(name); + if (savedRequest == null) { + return super.getParameterValues(name); + } - if (savedRequestParams == null) { - return wrappedRequestParams; - } - - if (wrappedRequestParams == null) { - return savedRequestParams; - } + String[] savedRequestParams = savedRequest.getParameterValues(name); + String[] wrappedRequestParams = super.getParameterValues(name); - // We have params in both saved and wrapped requests so have to merge them - List wrappedParamsList = Arrays.asList(wrappedRequestParams); - List combinedParams = new ArrayList(wrappedParamsList); + if (savedRequestParams == null) { + return wrappedRequestParams; + } - // We want to add all parameters of the saved request *apart from* duplicates of those already added - for (int i = 0; i < savedRequestParams.length; i++) { - if (!wrappedParamsList.contains(savedRequestParams[i])) { - combinedParams.add(savedRequestParams[i]); - } - } + if (wrappedRequestParams == null) { + return savedRequestParams; + } - return (String[]) combinedParams.toArray(new String[combinedParams.size()]); + // We have parameters in both saved and wrapped requests so have to merge them + List wrappedParamsList = Arrays.asList(wrappedRequestParams); + List combinedParams = new ArrayList(wrappedParamsList); + + // We want to add all parameters of the saved request *apart from* duplicates of those already added + for (int i = 0; i < savedRequestParams.length; i++) { + if (!wrappedParamsList.contains(savedRequestParams[i])) { + combinedParams.add(savedRequestParams[i]); + } + } + + return combinedParams.toArray(new String[combinedParams.size()]); } } diff --git a/core/src/main/java/org/springframework/security/wrapper/SecurityContextHolderAwareRequestWrapper.java b/core/src/main/java/org/springframework/security/wrapper/SecurityContextHolderAwareRequestWrapper.java index d651c831c6..4244fb2c47 100644 --- a/core/src/main/java/org/springframework/security/wrapper/SecurityContextHolderAwareRequestWrapper.java +++ b/core/src/main/java/org/springframework/security/wrapper/SecurityContextHolderAwareRequestWrapper.java @@ -89,6 +89,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest * * @return the username or null if unavailable */ + @Override public String getRemoteUser() { Authentication auth = getAuthentication(); @@ -109,6 +110,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest * * @return the Authentication, or null */ + @Override public Principal getUserPrincipal() { Authentication auth = getAuthentication(); @@ -158,6 +160,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest * @return true if an exact (case sensitive) matching granted authority is located, * false otherwise */ + @Override public boolean isUserInRole(String role) { return isGranted(role); } diff --git a/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java b/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java index 97cd35930b..8b0a1ac5e7 100644 --- a/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java +++ b/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java @@ -2,86 +2,200 @@ package org.springframework.security.wrapper; import static org.junit.Assert.*; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Enumeration; +import java.util.Locale; + +import javax.servlet.http.Cookie; + import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.ui.AbstractProcessingFilter; +import org.springframework.security.ui.savedrequest.FastHttpDateFormat; import org.springframework.security.ui.savedrequest.SavedRequest; import org.springframework.security.util.PortResolverImpl; public class SavedRequestAwareWrapperTests { - - @Test - /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request) - * and then RequestDispatcher.forward() it to /someUrl?action=bar. - * What should action parameter be before and during the forward? - **/ - public void wrappedRequestParameterTakesPrecedenceOverSavedRequest() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setParameter("action", "foo"); - SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); - MockHttpServletRequest request2 = new MockHttpServletRequest(); - request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); - SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); - assertEquals("foo", wrapper.getParameter("action")); - // The request after forward - request2.setParameter("action", "bar"); - assertEquals("bar", wrapper.getParameter("action")); - // Both values should be set, but "bar" should be first - assertEquals(2, wrapper.getParameterValues("action").length); - assertEquals("bar", wrapper.getParameterValues("action")[0]); - } - @Test - public void savedRequestDoesntCreateDuplicateParams() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setParameter("action", "foo"); - SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); - MockHttpServletRequest request2 = new MockHttpServletRequest(); - request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); - request2.setParameter("action", "foo"); - SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); - assertEquals(1, wrapper.getParameterValues("action").length); - assertEquals(1, wrapper.getParameterMap().size()); - assertEquals(1, ((String[])wrapper.getParameterMap().get("action")).length); - } - - @Test - public void savedRequestHeadersTakePrecedence() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader("Authorization","foo"); - SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); + private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) { + if (requestToSave != null) { + SavedRequest savedRequest = new SavedRequest(requestToSave, new PortResolverImpl()); + requestToWrap.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); + } + return new SavedRequestAwareWrapper(requestToWrap, new PortResolverImpl(),"ROLE_"); + } - MockHttpServletRequest request2 = new MockHttpServletRequest(); - request2.addHeader("Authorization","bar"); - request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); + @Test + public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception { + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")}); + SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest); + assertEquals(1, wrapper.getCookies().length); + assertEquals("fromwrapped", wrapper.getCookies()[0].getValue()); + } - SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); + @Test + public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception { + MockHttpServletRequest savedRequest = new MockHttpServletRequest(); + savedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromsaved")}); + SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, new MockHttpServletRequest()); + assertEquals(1, wrapper.getCookies().length); + assertEquals("fromsaved", wrapper.getCookies()[0].getValue()); + } - assertEquals("foo", wrapper.getHeader("Authorization")); - } + @Test + public void savedRequesthHeaderIsReturnedIfSavedRequestIsSet() throws Exception { + MockHttpServletRequest savedRequest = new MockHttpServletRequest(); + savedRequest.addHeader("header", "savedheader"); + SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, new MockHttpServletRequest()); + + assertNull(wrapper.getHeader("nonexistent")); + Enumeration headers = wrapper.getHeaders("nonexistent"); + assertFalse(headers.hasMoreElements()); + + assertEquals("savedheader", wrapper.getHeader("header")); + headers = wrapper.getHeaders("header"); + assertTrue(headers.hasMoreElements()); + assertEquals("savedheader", headers.nextElement()); + assertFalse(headers.hasMoreElements()); + assertTrue(wrapper.getHeaderNames().hasMoreElements()); + assertEquals("header", wrapper.getHeaderNames().nextElement()); + } + + @Test + public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception { + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + wrappedRequest.addHeader("header", "wrappedheader"); + SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest); + + assertNull(wrapper.getHeader("nonexistent")); + Enumeration headers = wrapper.getHeaders("nonexistent"); + assertFalse(headers.hasMoreElements()); + + assertEquals("wrappedheader", wrapper.getHeader("header")); + headers = wrapper.getHeaders("header"); + assertTrue(headers.hasMoreElements()); + assertEquals("wrappedheader", headers.nextElement()); + assertFalse(headers.hasMoreElements()); + assertTrue(wrapper.getHeaderNames().hasMoreElements()); + assertEquals("header", wrapper.getHeaderNames().nextElement()); + } + + + @Test + /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request) + * and then RequestDispatcher.forward() it to /someUrl?action=bar. + * What should action parameter be before and during the forward? + **/ + public void wrappedRequestParameterTakesPrecedenceOverSavedRequest() { + MockHttpServletRequest savedRequest = new MockHttpServletRequest(); + savedRequest.setParameter("action", "foo"); + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest); + assertEquals("foo", wrapper.getParameter("action")); + // The request after forward + wrappedRequest.setParameter("action", "bar"); + assertEquals("bar", wrapper.getParameter("action")); + // Both values should be set, but "bar" should be first + assertEquals(2, wrapper.getParameterValues("action").length); + assertEquals("bar", wrapper.getParameterValues("action")[0]); + } + + @Test + public void savedRequestDoesntCreateDuplicateParams() { + MockHttpServletRequest savedRequest = new MockHttpServletRequest(); + savedRequest.setParameter("action", "foo"); + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + wrappedRequest.setParameter("action", "foo"); + SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest); + assertEquals(1, wrapper.getParameterValues("action").length); + assertEquals(1, wrapper.getParameterMap().size()); + assertEquals(1, ((String[])wrapper.getParameterMap().get("action")).length); + } + + @Test + public void savedRequestHeadersTakePrecedence() { + MockHttpServletRequest savedRequest = new MockHttpServletRequest(); + savedRequest.addHeader("Authorization","foo"); + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + wrappedRequest.addHeader("Authorization","bar"); + SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest); + assertEquals("foo", wrapper.getHeader("Authorization")); + } + + @Test + public void getParameterValuesReturnsNullIfParameterIsntSet() { + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(wrappedRequest, new PortResolverImpl(), "ROLE_"); + assertNull(wrapper.getParameterValues("action")); + assertNull(wrapper.getParameterMap().get("action")); + } + + @Test + public void getParameterValuesReturnsCombinedSavedAndWrappedRequestValues() { + MockHttpServletRequest savedRequest = new MockHttpServletRequest(); + savedRequest.setParameter("action", "foo"); + MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); + SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest); + + assertArrayEquals(new Object[] {"foo"}, wrapper.getParameterValues("action")); + wrappedRequest.setParameter("action", "bar"); + assertArrayEquals(new Object[] {"bar","foo"}, wrapper.getParameterValues("action")); + // Check map is consistent + String[] valuesFromMap = (String[]) wrapper.getParameterMap().get("action"); + assertEquals(2, valuesFromMap.length); + assertEquals("bar", valuesFromMap[0]); + } + + @Test + public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception { + SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); + String nowString = FastHttpDateFormat.getCurrentDate(); + Date now = formatter.parse(nowString); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("header", nowString); + SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest()); + assertEquals(now.getTime(), wrapper.getDateHeader("header")); + + assertEquals(-1L, wrapper.getDateHeader("nonexistent")); + + // Now try with no saved request + request = new MockHttpServletRequest(); + request.addHeader("header", now); + wrapper = createWrapper(null, request); + assertEquals(now.getTime(), wrapper.getDateHeader("header")); + } + + @Test(expected=IllegalArgumentException.class) + public void invalidDateHeaderIsRejected() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("header", "notadate"); + SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest()); + wrapper.getDateHeader("header"); + } + + @Test + public void correctHttpMethodIsReturned() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused"); + SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused")); + assertEquals("PUT", wrapper.getMethod()); + wrapper = createWrapper(null, request); + assertEquals("PUT", wrapper.getMethod()); + } + + @Test + public void correctIntHeaderIsReturned() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("header", "999"); + request.addHeader("header", "1000"); + SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest()); + + assertEquals(999, wrapper.getIntHeader("header")); + assertEquals(-1, wrapper.getIntHeader("nonexistent")); + + wrapper = createWrapper(null, request); + assertEquals(999, wrapper.getIntHeader("header")); + } - @Test - public void getParameterValuesReturnsNullIfParameterIsntSet() { - MockHttpServletRequest request = new MockHttpServletRequest(); - SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request, new PortResolverImpl(), "ROLE_"); - assertNull(wrapper.getParameterValues("action")); - assertNull(wrapper.getParameterMap().get("action")); - } - - @Test - public void getParameterValuesReturnsCombinedValues() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setParameter("action", "foo"); - SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); - MockHttpServletRequest request2 = new MockHttpServletRequest(); - request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); - SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); - assertArrayEquals(new Object[] {"foo"}, wrapper.getParameterValues("action")); - request2.setParameter("action", "bar"); - assertArrayEquals(new Object[] {"bar","foo"}, wrapper.getParameterValues("action")); - // Check map is consistent - String[] valuesFromMap = (String[]) wrapper.getParameterMap().get("action"); - assertEquals(2, valuesFromMap.length); - assertEquals("bar", valuesFromMap[0]); - } }