diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 8f4e564bb7..0d996fa05c 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -33,7 +33,6 @@ import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; -import javax.servlet.ServletRequestWrapper; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -116,7 +115,7 @@ import java.util.*; * @author Carlos Sanchez * @author Ben Alex * @author Luke Taylor - * + * @author Rob Winch */ public class FilterChainProxy extends GenericFilterBean { //~ Static fields/initializers ===================================================================================== @@ -165,7 +164,7 @@ public class FilterChainProxy extends GenericFilterBean { return; } - VirtualFilterChain vfc = new VirtualFilterChain(url, chain, filters); + VirtualFilterChain vfc = new VirtualFilterChain(url, chain, filters, fwRequest); vfc.doFilter(fwRequest, fwResponse); } @@ -347,13 +346,15 @@ public class FilterChainProxy extends GenericFilterBean { private static class VirtualFilterChain implements FilterChain { private final FilterChain originalChain; private final List additionalFilters; + private final FirewalledRequest firewalledRequest; private final String url; private int currentPosition = 0; - private VirtualFilterChain(String url, FilterChain chain, List additionalFilters) { + private VirtualFilterChain(String url, FilterChain chain, List additionalFilters, FirewalledRequest firewalledRequest) { this.originalChain = chain; this.url = url; this.additionalFilters = additionalFilters; + this.firewalledRequest = firewalledRequest; } public void doFilter(final ServletRequest request, final ServletResponse response) throws IOException, ServletException { @@ -363,7 +364,7 @@ public class FilterChainProxy extends GenericFilterBean { } // Deactivate path stripping as we exit the security filter chain - resetWrapper(request); + this.firewalledRequest.reset(); originalChain.doFilter(request, response); } else { @@ -380,16 +381,6 @@ public class FilterChainProxy extends GenericFilterBean { nextFilter.doFilter(request, response, this); } } - - private void resetWrapper(ServletRequest request) { - while (request instanceof ServletRequestWrapper) { - if (request instanceof FirewalledRequest) { - ((FirewalledRequest)request).reset(); - break; - } - request = ((ServletRequestWrapper)request).getRequest(); - } - } } public interface FilterChainValidator { diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 62cb1fb68a..385034dc9e 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -20,6 +20,7 @@ import java.util.*; /** * @author Luke Taylor + * @author Rob Winch */ @SuppressWarnings({"unchecked"}) public class FilterChainProxyTests { @@ -114,4 +115,28 @@ public class FilterChainProxyTests { verify(fwr).reset(); } -} + // SEC-1639 + @Test + public void bothWrappersAreResetWithNestedFcps() throws Exception { + HttpFirewall fw = mock(HttpFirewall.class); + FilterChainProxy firstFcp = new FilterChainProxy(); + LinkedHashMap fcm = new LinkedHashMap(); + fcm.put("/match", Arrays.asList(fcp)); + firstFcp.setFilterChainMap(fcm); + firstFcp.setFirewall(fw); + fcp.setFirewall(fw); + FirewalledRequest firstFwr = mock(FirewalledRequest.class, "firstFwr"); + when(firstFwr.getRequestURI()).thenReturn("/match"); + when(firstFwr.getContextPath()).thenReturn(""); + FirewalledRequest fwr = mock(FirewalledRequest.class, "fwr"); + when(fwr.getRequestURI()).thenReturn("/match"); + when(fwr.getContextPath()).thenReturn(""); + when(fw.getFirewalledRequest(request)).thenReturn(firstFwr); + when(fw.getFirewalledRequest(firstFwr)).thenReturn(fwr); + when(fwr.getRequest()).thenReturn(firstFwr); + when(firstFwr.getRequest()).thenReturn(request); + firstFcp.doFilter(request, response, chain); + verify(firstFwr).reset(); + verify(fwr).reset(); + } +} \ No newline at end of file