Further SavedRequestWrapper related tests and tidying up.

This commit is contained in:
Luke Taylor 2008-11-21 12:17:43 +00:00
parent 7e562031cc
commit 6b24637fbc
5 changed files with 324 additions and 211 deletions

View File

@ -33,44 +33,42 @@ import java.util.NoSuchElementException;
* @author Andrey Grebnev * @author Andrey Grebnev
* @version $Id$ * @version $Id$
*/ */
@SuppressWarnings("unchecked") public class Enumerator<T> implements Enumeration<T> {
public class Enumerator implements Enumeration {
//~ Instance fields ================================================================================================ //~ Instance fields ================================================================================================
/** /**
* The <code>Iterator</code> over which the <code>Enumeration</code> represented by this class actually operates. * The <code>Iterator</code> over which the <code>Enumeration</code> represented by this class actually operates.
*/ */
private Iterator iterator = null; private Iterator<T> iterator = null;
//~ Constructors =================================================================================================== //~ Constructors ===================================================================================================
/** /**
* Return an Enumeration over the values of the specified Collection. * Return an Enumeration over the values of the specified Collection.
* *
* @param collection Collection whose values should be enumerated * @param collection Collection whose values should be enumerated
*/ */
public Enumerator(Collection collection) { public Enumerator(Collection<T> collection) {
this(collection.iterator()); this(collection.iterator());
} }
/** /**
* Return an Enumeration over the values of the specified Collection. * Return an Enumeration over the values of the specified Collection.
* *
* @param collection Collection whose values should be enumerated * @param collection Collection whose values should be enumerated
* @param clone true to clone iterator * @param clone true to clone iterator
*/ */
public Enumerator(Collection collection, boolean clone) { public Enumerator(Collection<T> collection, boolean clone) {
this(collection.iterator(), clone); this(collection.iterator(), clone);
} }
/** /**
* Return an Enumeration over the values returned by the specified * Return an Enumeration over the values returned by the specified
* Iterator. * Iterator.
* *
* @param iterator Iterator to be wrapped * @param iterator Iterator to be wrapped
*/ */
public Enumerator(Iterator iterator) { public Enumerator(Iterator<T> iterator) {
super();
this.iterator = iterator; this.iterator = iterator;
} }
@ -81,12 +79,12 @@ public class Enumerator implements Enumeration {
* @param iterator Iterator to be wrapped * @param iterator Iterator to be wrapped
* @param clone true to clone iterator * @param clone true to clone iterator
*/ */
public Enumerator(Iterator iterator, boolean clone) { public Enumerator(Iterator<T> iterator, boolean clone) {
if (!clone) { if (!clone) {
this.iterator = iterator; this.iterator = iterator;
} else { } else {
List list = new ArrayList(); List<T> list = new ArrayList<T>();
while (iterator.hasNext()) { while (iterator.hasNext()) {
list.add(iterator.next()); list.add(iterator.next());
@ -101,17 +99,17 @@ public class Enumerator implements Enumeration {
* *
* @param map Map whose values should be enumerated * @param map Map whose values should be enumerated
*/ */
public Enumerator(Map map) { public Enumerator(Map<?, T> map) {
this(map.values().iterator()); this(map.values().iterator());
} }
/** /**
* Return an Enumeration over the values of the specified Map. * Return an Enumeration over the values of the specified Map.
* *
* @param map Map whose values should be enumerated * @param map Map whose values should be enumerated
* @param clone true to clone iterator * @param clone true to clone iterator
*/ */
public Enumerator(Map map, boolean clone) { public Enumerator(Map<?, T> map, boolean clone) {
this(map.values().iterator(), clone); this(map.values().iterator(), clone);
} }
@ -135,7 +133,7 @@ public class Enumerator implements Enumeration {
* *
* @exception NoSuchElementException if no more elements exist * @exception NoSuchElementException if no more elements exist
*/ */
public Object nextElement() throws NoSuchElementException { public T nextElement() throws NoSuchElementException {
return (iterator.next()); return (iterator.next());
} }
} }

View File

@ -24,6 +24,7 @@ import org.springframework.util.Assert;
import javax.servlet.http.Cookie; import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
@ -46,7 +47,6 @@ import java.util.TreeMap;
* @author Ben Alex * @author Ben Alex
* @version $Id$ * @version $Id$
*/ */
@SuppressWarnings("unchecked")
public class SavedRequest implements java.io.Serializable { public class SavedRequest implements java.io.Serializable {
//~ Static fields/initializers ===================================================================================== //~ Static fields/initializers =====================================================================================
@ -54,10 +54,10 @@ public class SavedRequest implements java.io.Serializable {
//~ Instance fields ================================================================================================ //~ Instance fields ================================================================================================
private ArrayList cookies = new ArrayList(); private ArrayList<SavedCookie> cookies = new ArrayList<SavedCookie>();
private ArrayList locales = new ArrayList(); private ArrayList<Locale> locales = new ArrayList<Locale>();
private Map headers = new TreeMap(String.CASE_INSENSITIVE_ORDER); private Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER);
private Map parameters = new TreeMap(String.CASE_INSENSITIVE_ORDER); private Map<String, String[]> parameters = new TreeMap<String, String[]>(String.CASE_INSENSITIVE_ORDER);
private String contextPath; private String contextPath;
private String method; private String method;
private String pathInfo; private String pathInfo;
@ -71,6 +71,7 @@ public class SavedRequest implements java.io.Serializable {
//~ Constructors =================================================================================================== //~ Constructors ===================================================================================================
@SuppressWarnings("unchecked")
public SavedRequest(HttpServletRequest request, PortResolver portResolver) { public SavedRequest(HttpServletRequest request, PortResolver portResolver) {
Assert.notNull(request, "Request required"); Assert.notNull(request, "Request required");
Assert.notNull(portResolver, "PortResolver required"); Assert.notNull(portResolver, "PortResolver required");
@ -85,20 +86,19 @@ public class SavedRequest implements java.io.Serializable {
} }
// Headers // Headers
Enumeration names = request.getHeaderNames(); Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) { while (names.hasMoreElements()) {
String name = (String) names.nextElement(); String name = names.nextElement();
Enumeration values = request.getHeaders(name); Enumeration<String> values = request.getHeaders(name);
while (values.hasMoreElements()) { while (values.hasMoreElements()) {
String value = (String) values.nextElement(); this.addHeader(name, values.nextElement());
this.addHeader(name, value);
} }
} }
// Locales // Locales
Enumeration locales = request.getLocales(); Enumeration<Locale> locales = request.getLocales();
while (locales.hasMoreElements()) { while (locales.hasMoreElements()) {
Locale locale = (Locale) locales.nextElement(); Locale locale = (Locale) locales.nextElement();
@ -106,15 +106,12 @@ public class SavedRequest implements java.io.Serializable {
} }
// Parameters // Parameters
Map parameters = request.getParameterMap(); Map<String,Object> parameters = request.getParameterMap();
Iterator paramNames = parameters.keySet().iterator();
while (paramNames.hasNext()) { for(String paramName : parameters.keySet()) {
String paramName = (String) paramNames.next(); Object paramValues = parameters.get(paramName);
Object o = parameters.get(paramName); if (paramValues instanceof String[]) {
if (o instanceof String[]) { this.addParameter(paramName, (String[]) paramValues);
String[] paramValues = (String[]) o;
this.addParameter(paramName, paramValues);
} else { } else {
if (logger.isWarnEnabled()) { if (logger.isWarnEnabled()) {
logger.warn("ServletRequest.getParameterMap() returned non-String array"); logger.warn("ServletRequest.getParameterMap() returned non-String array");
@ -142,10 +139,10 @@ public class SavedRequest implements java.io.Serializable {
} }
private void addHeader(String name, String value) { private void addHeader(String name, String value) {
ArrayList values = (ArrayList) headers.get(name); List<String> values = headers.get(name);
if (values == null) { if (values == null) {
values = new ArrayList(); values = new ArrayList<String>();
headers.put(name, values); headers.put(name, values);
} }
@ -163,10 +160,6 @@ public class SavedRequest implements java.io.Serializable {
/** /**
* Determines if the current request matches the <code>SavedRequest</code>. All URL arguments are * Determines if the current request matches the <code>SavedRequest</code>. All URL arguments are
* considered, but <em>not</em> method (POST/GET), cookies, locales, headers or parameters. * considered, but <em>not</em> method (POST/GET), cookies, locales, headers or parameters.
*
* @param request DOCUMENT ME!
* @param portResolver DOCUMENT ME!
* @return DOCUMENT ME!
*/ */
public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) { public boolean doesRequestMatch(HttpServletRequest request, PortResolver portResolver) {
Assert.notNull(request, "Request required"); Assert.notNull(request, "Request required");
@ -216,12 +209,13 @@ public class SavedRequest implements java.io.Serializable {
return contextPath; return contextPath;
} }
public List getCookies() { public List<Cookie> getCookies() {
List cookieList = new ArrayList(cookies.size()); List<Cookie> cookieList = new ArrayList<Cookie>(cookies.size());
for (Iterator iterator = cookies.iterator(); iterator.hasNext();) {
SavedCookie savedCookie = (SavedCookie) iterator.next(); for (SavedCookie savedCookie : cookies) {
cookieList.add(savedCookie.getCookie()); cookieList.add(savedCookie.getCookie());
} }
return cookieList; return cookieList;
} }
@ -234,33 +228,33 @@ public class SavedRequest implements java.io.Serializable {
return UrlUtils.getFullRequestUrl(this); return UrlUtils.getFullRequestUrl(this);
} }
public Iterator getHeaderNames() { public Iterator<String> getHeaderNames() {
return (headers.keySet().iterator()); return (headers.keySet().iterator());
} }
public Iterator getHeaderValues(String name) { public Iterator<String> getHeaderValues(String name) {
ArrayList values = (ArrayList) headers.get(name); List<String> values = headers.get(name);
if (values == null) { if (values == null) {
return ((new ArrayList()).iterator()); values = Collections.emptyList();
} else {
return (values.iterator());
} }
return (values.iterator());
} }
public Iterator getLocales() { public Iterator<Locale> getLocales() {
return (locales.iterator()); return (locales.iterator());
} }
public String getMethod() { public String getMethod() {
return (this.method); return method;
} }
public Map getParameterMap() { public Map<String, String[]> getParameterMap() {
return parameters; return parameters;
} }
public Iterator getParameterNames() { public Iterator<String> getParameterNames() {
return (parameters.keySet().iterator()); return (parameters.keySet().iterator());
} }

View File

@ -15,18 +15,7 @@
package org.springframework.security.wrapper; package org.springframework.security.wrapper;
import org.springframework.security.ui.AbstractProcessingFilter;
import org.springframework.security.ui.savedrequest.Enumerator;
import org.springframework.security.ui.savedrequest.FastHttpDateFormat;
import org.springframework.security.ui.savedrequest.SavedRequest;
import org.springframework.security.util.PortResolver;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Enumeration; import java.util.Enumeration;
@ -38,12 +27,19 @@ import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TimeZone; import java.util.TimeZone;
import java.util.Map.Entry;
import javax.servlet.http.Cookie; import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.ui.AbstractProcessingFilter;
import org.springframework.security.ui.savedrequest.Enumerator;
import org.springframework.security.ui.savedrequest.FastHttpDateFormat;
import org.springframework.security.ui.savedrequest.SavedRequest;
import org.springframework.security.util.PortResolver;
/** /**
* Provides request parameters, headers and cookies from either an original request or a saved request. * Provides request parameters, headers and cookies from either an original request or a saved request.
@ -121,16 +117,18 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
//~ Methods ======================================================================================================== //~ Methods ========================================================================================================
@Override
public Cookie[] getCookies() { public Cookie[] getCookies() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getCookies(); return super.getCookies();
} else { } else {
List cookies = savedRequest.getCookies(); List<Cookie> cookies = savedRequest.getCookies();
return (Cookie[]) cookies.toArray(new Cookie[cookies.size()]); return cookies.toArray(new Cookie[cookies.size()]);
} }
} }
@Override
public long getDateHeader(String name) { public long getDateHeader(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getDateHeader(name); return super.getDateHeader(name);
@ -152,15 +150,16 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
} }
} }
@Override
public String getHeader(String name) { public String getHeader(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getHeader(name); return super.getHeader(name);
} else { } else {
String header = null; String header = null;
Iterator iterator = savedRequest.getHeaderValues(name); Iterator<String> iterator = savedRequest.getHeaderValues(name);
while (iterator.hasNext()) { while (iterator.hasNext()) {
header = (String) iterator.next(); header = iterator.next();
break; break;
} }
@ -169,22 +168,25 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
} }
} }
@Override
public Enumeration getHeaderNames() { public Enumeration getHeaderNames() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getHeaderNames(); return super.getHeaderNames();
} else { } else {
return new Enumerator(savedRequest.getHeaderNames()); return new Enumerator<String>(savedRequest.getHeaderNames());
} }
} }
@Override
public Enumeration getHeaders(String name) { public Enumeration getHeaders(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getHeaders(name); return super.getHeaders(name);
} else { } else {
return new Enumerator(savedRequest.getHeaderValues(name)); return new Enumerator<String>(savedRequest.getHeaderValues(name));
} }
} }
@Override
public int getIntHeader(String name) { public int getIntHeader(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getIntHeader(name); return super.getIntHeader(name);
@ -199,12 +201,13 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
} }
} }
@Override
public Locale getLocale() { public Locale getLocale() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getLocale(); return super.getLocale();
} else { } else {
Locale locale = null; Locale locale = null;
Iterator iterator = savedRequest.getLocales(); Iterator<Locale> iterator = savedRequest.getLocales();
while (iterator.hasNext()) { while (iterator.hasNext()) {
locale = (Locale) iterator.next(); locale = (Locale) iterator.next();
@ -220,23 +223,25 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
} }
} }
@Override
public Enumeration getLocales() { public Enumeration getLocales() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getLocales(); return super.getLocales();
} else {
Iterator iterator = savedRequest.getLocales();
if (iterator.hasNext()) {
return new Enumerator(iterator);
} else {
ArrayList results = new ArrayList();
results.add(defaultLocale);
return new Enumerator(results.iterator());
}
} }
Iterator<Locale> iterator = savedRequest.getLocales();
if (iterator.hasNext()) {
return new Enumerator<Locale>(iterator);
}
// Fall back to default locale
ArrayList<Locale> results = new ArrayList<Locale>(1);
results.add(defaultLocale);
return new Enumerator<Locale>(results.iterator());
} }
@Override
public String getMethod() { public String getMethod() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getMethod(); return super.getMethod();
@ -257,83 +262,82 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
* If the value from the wrapped request is null, an attempt will be made to retrieve the parameter * If the value from the wrapped request is null, an attempt will be made to retrieve the parameter
* from the SavedRequest, if available.. * from the SavedRequest, if available..
*/ */
@Override
public String getParameter(String name) { public String getParameter(String name) {
String value = super.getParameter(name); String value = super.getParameter(name);
if (value != null || savedRequest == null) { if (value != null || savedRequest == null) {
return value; return value;
} }
String[] values = savedRequest.getParameterValues(name); String[] values = savedRequest.getParameterValues(name);
if (values == null) if (values == null || values.length == 0) {
return null; return null;
for (int i = 0; i < values.length; i++) { }
value = values[i];
break;
}
return value; return values[0];
} }
@Override
public Map getParameterMap() { public Map getParameterMap() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getParameterMap(); return super.getParameterMap();
} }
Set names = getCombinedParameterNames(); Set<String> names = getCombinedParameterNames();
Iterator nameIter = names.iterator(); Map<String, String[]> parameterMap = new HashMap<String, String[]>(names.size());
Map parameterMap = new HashMap(names.size());
while (nameIter.hasNext()) { for (String name : names) {
String name = (String) nameIter.next(); parameterMap.put(name, getParameterValues(name));
parameterMap.put(name, getParameterValues(name)); }
}
return parameterMap; return parameterMap;
} }
private Set getCombinedParameterNames() { private Set<String> getCombinedParameterNames() {
Set names = new HashSet(); Set<String> names = new HashSet<String>();
names.addAll(super.getParameterMap().keySet()); names.addAll(super.getParameterMap().keySet());
if (savedRequest != null) { if (savedRequest != null) {
names.addAll(savedRequest.getParameterMap().keySet()); names.addAll(savedRequest.getParameterMap().keySet());
} }
return names; return names;
} }
@Override
public Enumeration getParameterNames() { public Enumeration getParameterNames() {
return new Enumerator(getCombinedParameterNames()); return new Enumerator(getCombinedParameterNames());
} }
@Override
public String[] getParameterValues(String name) { public String[] getParameterValues(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getParameterValues(name); return super.getParameterValues(name);
} }
String[] savedRequestParams = savedRequest.getParameterValues(name); String[] savedRequestParams = savedRequest.getParameterValues(name);
String[] wrappedRequestParams = super.getParameterValues(name); String[] wrappedRequestParams = super.getParameterValues(name);
if (savedRequestParams == null) { if (savedRequestParams == null) {
return wrappedRequestParams; return wrappedRequestParams;
} }
if (wrappedRequestParams == null) { if (wrappedRequestParams == null) {
return savedRequestParams; return savedRequestParams;
} }
// We have params in both saved and wrapped requests so have to merge them // We have parameters in both saved and wrapped requests so have to merge them
List wrappedParamsList = Arrays.asList(wrappedRequestParams); List<String> wrappedParamsList = Arrays.asList(wrappedRequestParams);
List combinedParams = new ArrayList(wrappedParamsList); List<String> combinedParams = new ArrayList<String>(wrappedParamsList);
// We want to add all parameters of the saved request *apart from* duplicates of those already added // We want to add all parameters of the saved request *apart from* duplicates of those already added
for (int i = 0; i < savedRequestParams.length; i++) { for (int i = 0; i < savedRequestParams.length; i++) {
if (!wrappedParamsList.contains(savedRequestParams[i])) { if (!wrappedParamsList.contains(savedRequestParams[i])) {
combinedParams.add(savedRequestParams[i]); combinedParams.add(savedRequestParams[i]);
} }
} }
return (String[]) combinedParams.toArray(new String[combinedParams.size()]); return combinedParams.toArray(new String[combinedParams.size()]);
} }
} }

View File

@ -89,6 +89,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest
* *
* @return the username or <code>null</code> if unavailable * @return the username or <code>null</code> if unavailable
*/ */
@Override
public String getRemoteUser() { public String getRemoteUser() {
Authentication auth = getAuthentication(); Authentication auth = getAuthentication();
@ -109,6 +110,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest
* *
* @return the <code>Authentication</code>, or <code>null</code> * @return the <code>Authentication</code>, or <code>null</code>
*/ */
@Override
public Principal getUserPrincipal() { public Principal getUserPrincipal() {
Authentication auth = getAuthentication(); Authentication auth = getAuthentication();
@ -158,6 +160,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest
* @return <code>true</code> if an <b>exact</b> (case sensitive) matching granted authority is located, * @return <code>true</code> if an <b>exact</b> (case sensitive) matching granted authority is located,
* <code>false</code> otherwise * <code>false</code> otherwise
*/ */
@Override
public boolean isUserInRole(String role) { public boolean isUserInRole(String role) {
return isGranted(role); return isGranted(role);
} }

View File

@ -2,86 +2,200 @@ package org.springframework.security.wrapper;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Enumeration;
import java.util.Locale;
import javax.servlet.http.Cookie;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.ui.AbstractProcessingFilter; import org.springframework.security.ui.AbstractProcessingFilter;
import org.springframework.security.ui.savedrequest.FastHttpDateFormat;
import org.springframework.security.ui.savedrequest.SavedRequest; import org.springframework.security.ui.savedrequest.SavedRequest;
import org.springframework.security.util.PortResolverImpl; import org.springframework.security.util.PortResolverImpl;
public class SavedRequestAwareWrapperTests { public class SavedRequestAwareWrapperTests {
@Test private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
/* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request) if (requestToSave != null) {
* and then RequestDispatcher.forward() it to /someUrl?action=bar. SavedRequest savedRequest = new SavedRequest(requestToSave, new PortResolverImpl());
* What should action parameter be before and during the forward? requestToWrap.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest);
**/ }
public void wrappedRequestParameterTakesPrecedenceOverSavedRequest() { return new SavedRequestAwareWrapper(requestToWrap, new PortResolverImpl(),"ROLE_");
MockHttpServletRequest request = new MockHttpServletRequest(); }
request.setParameter("action", "foo");
SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl());
MockHttpServletRequest request2 = new MockHttpServletRequest();
request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest);
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_");
assertEquals("foo", wrapper.getParameter("action"));
// The request after forward
request2.setParameter("action", "bar");
assertEquals("bar", wrapper.getParameter("action"));
// Both values should be set, but "bar" should be first
assertEquals(2, wrapper.getParameterValues("action").length);
assertEquals("bar", wrapper.getParameterValues("action")[0]);
}
@Test @Test
public void savedRequestDoesntCreateDuplicateParams() { public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
request.setParameter("action", "foo"); wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")});
SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
MockHttpServletRequest request2 = new MockHttpServletRequest(); assertEquals(1, wrapper.getCookies().length);
request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); assertEquals("fromwrapped", wrapper.getCookies()[0].getValue());
request2.setParameter("action", "foo"); }
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_");
assertEquals(1, wrapper.getParameterValues("action").length);
assertEquals(1, wrapper.getParameterMap().size());
assertEquals(1, ((String[])wrapper.getParameterMap().get("action")).length);
}
@Test @Test
public void savedRequestHeadersTakePrecedence() { public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest savedRequest = new MockHttpServletRequest();
request.addHeader("Authorization","foo"); savedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromsaved")});
SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl()); SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, new MockHttpServletRequest());
assertEquals(1, wrapper.getCookies().length);
assertEquals("fromsaved", wrapper.getCookies()[0].getValue());
}
MockHttpServletRequest request2 = new MockHttpServletRequest(); @Test
request2.addHeader("Authorization","bar"); public void savedRequesthHeaderIsReturnedIfSavedRequestIsSet() throws Exception {
request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest); MockHttpServletRequest savedRequest = new MockHttpServletRequest();
savedRequest.addHeader("header", "savedheader");
SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, new MockHttpServletRequest());
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_"); assertNull(wrapper.getHeader("nonexistent"));
Enumeration headers = wrapper.getHeaders("nonexistent");
assertFalse(headers.hasMoreElements());
assertEquals("foo", wrapper.getHeader("Authorization")); assertEquals("savedheader", wrapper.getHeader("header"));
} headers = wrapper.getHeaders("header");
assertTrue(headers.hasMoreElements());
assertEquals("savedheader", headers.nextElement());
assertFalse(headers.hasMoreElements());
assertTrue(wrapper.getHeaderNames().hasMoreElements());
assertEquals("header", wrapper.getHeaderNames().nextElement());
}
@Test @Test
public void getParameterValuesReturnsNullIfParameterIsntSet() { public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request, new PortResolverImpl(), "ROLE_"); wrappedRequest.addHeader("header", "wrappedheader");
assertNull(wrapper.getParameterValues("action")); SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
assertNull(wrapper.getParameterMap().get("action"));
} assertNull(wrapper.getHeader("nonexistent"));
Enumeration headers = wrapper.getHeaders("nonexistent");
assertFalse(headers.hasMoreElements());
assertEquals("wrappedheader", wrapper.getHeader("header"));
headers = wrapper.getHeaders("header");
assertTrue(headers.hasMoreElements());
assertEquals("wrappedheader", headers.nextElement());
assertFalse(headers.hasMoreElements());
assertTrue(wrapper.getHeaderNames().hasMoreElements());
assertEquals("header", wrapper.getHeaderNames().nextElement());
}
@Test
/* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request)
* and then RequestDispatcher.forward() it to /someUrl?action=bar.
* What should action parameter be before and during the forward?
**/
public void wrappedRequestParameterTakesPrecedenceOverSavedRequest() {
MockHttpServletRequest savedRequest = new MockHttpServletRequest();
savedRequest.setParameter("action", "foo");
MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest);
assertEquals("foo", wrapper.getParameter("action"));
// The request after forward
wrappedRequest.setParameter("action", "bar");
assertEquals("bar", wrapper.getParameter("action"));
// Both values should be set, but "bar" should be first
assertEquals(2, wrapper.getParameterValues("action").length);
assertEquals("bar", wrapper.getParameterValues("action")[0]);
}
@Test
public void savedRequestDoesntCreateDuplicateParams() {
MockHttpServletRequest savedRequest = new MockHttpServletRequest();
savedRequest.setParameter("action", "foo");
MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
wrappedRequest.setParameter("action", "foo");
SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest);
assertEquals(1, wrapper.getParameterValues("action").length);
assertEquals(1, wrapper.getParameterMap().size());
assertEquals(1, ((String[])wrapper.getParameterMap().get("action")).length);
}
@Test
public void savedRequestHeadersTakePrecedence() {
MockHttpServletRequest savedRequest = new MockHttpServletRequest();
savedRequest.addHeader("Authorization","foo");
MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
wrappedRequest.addHeader("Authorization","bar");
SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest);
assertEquals("foo", wrapper.getHeader("Authorization"));
}
@Test
public void getParameterValuesReturnsNullIfParameterIsntSet() {
MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(wrappedRequest, new PortResolverImpl(), "ROLE_");
assertNull(wrapper.getParameterValues("action"));
assertNull(wrapper.getParameterMap().get("action"));
}
@Test
public void getParameterValuesReturnsCombinedSavedAndWrappedRequestValues() {
MockHttpServletRequest savedRequest = new MockHttpServletRequest();
savedRequest.setParameter("action", "foo");
MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
SavedRequestAwareWrapper wrapper = createWrapper(savedRequest, wrappedRequest);
assertArrayEquals(new Object[] {"foo"}, wrapper.getParameterValues("action"));
wrappedRequest.setParameter("action", "bar");
assertArrayEquals(new Object[] {"bar","foo"}, wrapper.getParameterValues("action"));
// Check map is consistent
String[] valuesFromMap = (String[]) wrapper.getParameterMap().get("action");
assertEquals(2, valuesFromMap.length);
assertEquals("bar", valuesFromMap[0]);
}
@Test
public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception {
SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
String nowString = FastHttpDateFormat.getCurrentDate();
Date now = formatter.parse(nowString);
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("header", nowString);
SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest());
assertEquals(now.getTime(), wrapper.getDateHeader("header"));
assertEquals(-1L, wrapper.getDateHeader("nonexistent"));
// Now try with no saved request
request = new MockHttpServletRequest();
request.addHeader("header", now);
wrapper = createWrapper(null, request);
assertEquals(now.getTime(), wrapper.getDateHeader("header"));
}
@Test(expected=IllegalArgumentException.class)
public void invalidDateHeaderIsRejected() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("header", "notadate");
SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest());
wrapper.getDateHeader("header");
}
@Test
public void correctHttpMethodIsReturned() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused");
SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused"));
assertEquals("PUT", wrapper.getMethod());
wrapper = createWrapper(null, request);
assertEquals("PUT", wrapper.getMethod());
}
@Test
public void correctIntHeaderIsReturned() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("header", "999");
request.addHeader("header", "1000");
SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest());
assertEquals(999, wrapper.getIntHeader("header"));
assertEquals(-1, wrapper.getIntHeader("nonexistent"));
wrapper = createWrapper(null, request);
assertEquals(999, wrapper.getIntHeader("header"));
}
@Test
public void getParameterValuesReturnsCombinedValues() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("action", "foo");
SavedRequest savedRequest = new SavedRequest(request, new PortResolverImpl());
MockHttpServletRequest request2 = new MockHttpServletRequest();
request2.getSession().setAttribute(AbstractProcessingFilter.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest);
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(request2, new PortResolverImpl(), "ROLE_");
assertArrayEquals(new Object[] {"foo"}, wrapper.getParameterValues("action"));
request2.setParameter("action", "bar");
assertArrayEquals(new Object[] {"bar","foo"}, wrapper.getParameterValues("action"));
// Check map is consistent
String[] valuesFromMap = (String[]) wrapper.getParameterMap().get("action");
assertEquals(2, valuesFromMap.length);
assertEquals("bar", valuesFromMap[0]);
}
} }