diff --git a/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java b/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java index 698fbe61a0..a88ade4ffa 100644 --- a/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java +++ b/core/src/main/java/org/springframework/security/wrapper/SavedRequestAwareWrapper.java @@ -31,10 +31,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Enumeration; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.TimeZone; import java.util.Map.Entry; @@ -274,58 +276,62 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW } public Map getParameterMap() { - Map parameters = super.getParameterMap(); - if (savedRequest == null) { - return parameters; + return super.getParameterMap(); } - // We have a saved request so merge the values, with the wrapped request taking precedence (see getParameter()) - Map newParameters = new HashMap(savedRequest.getParameterMap().size() + parameters.size()); - newParameters.putAll(parameters); + Set names = getCombinedParameterNames(); + Iterator nameIter = names.iterator(); + Map parameterMap = new HashMap(names.size()); - Iterator savedParams = savedRequest.getParameterMap().entrySet().iterator(); - - while (savedParams.hasNext()) { - Map.Entry entry = (Entry) savedParams.next(); - String name = (String) entry.getKey(); - String[] savedParamValues = (String[]) entry.getValue(); - - if (newParameters.containsKey(name)) { - // merge values - String[] existingValues = (String[]) newParameters.get(name); - String[] mergedValues = new String[savedParamValues.length + existingValues.length]; - System.arraycopy(existingValues, 0, mergedValues, 0, existingValues.length); - System.arraycopy(savedParamValues, 0, mergedValues, existingValues.length, savedParamValues.length); - newParameters.put(name, mergedValues); - } else { - newParameters.put(name, savedParamValues); - } + while (nameIter.hasNext()) { + String name = (String) nameIter.next(); + parameterMap.put(name, getParameterValues(name)); } - - return newParameters; + + return parameterMap; + } + + private Set getCombinedParameterNames() { + Set names = new HashSet(); + names.addAll(super.getParameterMap().keySet()); + + if (savedRequest != null) { + names.addAll(savedRequest.getParameterMap().keySet()); + } + + return names; } public Enumeration getParameterNames() { - return new Enumerator(getParameterMap().keySet()); + return new Enumerator(getCombinedParameterNames()); } public String[] getParameterValues(String name) { - String[] savedRequestParams = savedRequest == null ? null : savedRequest.getParameterValues(name); + if (savedRequest == null) { + return super.getParameterValues(name); + } + + String[] savedRequestParams = savedRequest.getParameterValues(name); String[] wrappedRequestParams = super.getParameterValues(name); - if (savedRequestParams == null && wrappedRequestParams == null) { - return null; + if (savedRequestParams == null) { + return wrappedRequestParams; + } + + if (wrappedRequestParams == null) { + return savedRequestParams; } - List combinedParams = new ArrayList(); + // We have params in both saved and wrapped requests so have to merge them + List wrappedParamsList = Arrays.asList(wrappedRequestParams); + List combinedParams = new ArrayList(wrappedParamsList); - if (wrappedRequestParams != null) { - combinedParams.addAll(Arrays.asList(wrappedRequestParams)); - } - - if (savedRequestParams != null) { - combinedParams.addAll(Arrays.asList(savedRequestParams)); + // We want to add all parameters of the saved request *apart from* duplicates of those already added + for (int i = 0; i < savedRequestParams.length; i++) { + if (!wrappedParamsList.contains(savedRequestParams[i])) { + combinedParams.add(savedRequestParams[i]); + } } return (String[]) combinedParams.toArray(new String[combinedParams.size()]); diff --git a/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java b/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java index ba5b933c91..97cd35930b 100644 --- a/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java +++ b/core/src/test/java/org/springframework/security/wrapper/SavedRequestAwareWrapperTests.java @@ -11,7 +11,10 @@ import org.springframework.security.util.PortResolverImpl; public class SavedRequestAwareWrapperTests { @Test - /* SEC-830 */ + /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request) + * and then RequestDispatcher.forward() it to /someUrl?action=bar. + * What should action parameter be before and during the forward? + **/ public void wrappedRequestParameterTakesPrecedenceOverSavedRequest() { MockHttpServletRequest request = new MockHttpServletRequest(); request.setParameter("action", "foo"); @@ -20,8 +23,26 @@ public class SavedRequestAwareWrapperTests { request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); assertEquals("foo", wrapper.getParameter("action")); + // The request after forward request2.setParameter("action", "bar"); assertEquals("bar", wrapper.getParameter("action")); + // Both values should be set, but "bar" should be first + assertEquals(2, wrapper.getParameterValues("action").length); + assertEquals("bar", wrapper.getParameterValues("action")[0]); + } + + @Test + public void savedRequestDoesntCreateDuplicateParams() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("action", "foo"); + SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); + MockHttpServletRequest request2 = new MockHttpServletRequest(); + request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); + request2.setParameter("action", "foo"); + SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); + assertEquals(1, wrapper.getParameterValues("action").length); + assertEquals(1, wrapper.getParameterMap().size()); + assertEquals(1, ((String[])wrapper.getParameterMap().get("action")).length); } @Test