diff --git a/core/src/main/java/org/springframework/security/util/FilterChainProxy.java b/core/src/main/java/org/springframework/security/util/FilterChainProxy.java index 28adf99e29..1922227f4f 100644 --- a/core/src/main/java/org/springframework/security/util/FilterChainProxy.java +++ b/core/src/main/java/org/springframework/security/util/FilterChainProxy.java @@ -96,6 +96,7 @@ import java.util.*; * @author Carlos Sanchez * @author Ben Alex * @author Luke Taylor + * @author Rob Winch * * @version $Id$ */ @@ -183,7 +184,7 @@ public class FilterChainProxy implements Filter, InitializingBean, ApplicationCo return; } - VirtualFilterChain virtualFilterChain = new VirtualFilterChain(fi, filters); + VirtualFilterChain virtualFilterChain = new VirtualFilterChain(fi, filters, fwRequest); virtualFilterChain.doFilter(fi.getRequest(), fi.getResponse()); } @@ -376,11 +377,13 @@ public class FilterChainProxy implements Filter, InitializingBean, ApplicationCo private static class VirtualFilterChain implements FilterChain { private FilterInvocation fi; private List additionalFilters; + private FirewalledRequest firewalledRequest; private int currentPosition = 0; - private VirtualFilterChain(FilterInvocation filterInvocation, List additionalFilters) { + private VirtualFilterChain(FilterInvocation filterInvocation, List additionalFilters, FirewalledRequest firewalledRequest) { this.fi = filterInvocation; this.additionalFilters = additionalFilters; + this.firewalledRequest = firewalledRequest; } public void doFilter(ServletRequest request, ServletResponse response) @@ -391,7 +394,7 @@ public class FilterChainProxy implements Filter, InitializingBean, ApplicationCo + " reached end of additional filter chain; proceeding with original chain"); } // Deactivate path stripping as we exit the security filter chain - resetWrapper(request); + this.firewalledRequest.reset(); fi.getChain().doFilter(request, response); } else { @@ -408,16 +411,6 @@ public class FilterChainProxy implements Filter, InitializingBean, ApplicationCo nextFilter.doFilter(request, response, this); } } - - private void resetWrapper(ServletRequest request) { - while (request instanceof ServletRequestWrapper) { - if (request instanceof FirewalledRequest) { - ((FirewalledRequest)request).reset(); - break; - } - request = ((ServletRequestWrapper)request).getRequest(); - } - } } } diff --git a/core/src/test/java/org/springframework/security/util/FilterChainProxyTests.java b/core/src/test/java/org/springframework/security/util/FilterChainProxyTests.java index 5876c4c12e..2d038debd1 100644 --- a/core/src/test/java/org/springframework/security/util/FilterChainProxyTests.java +++ b/core/src/test/java/org/springframework/security/util/FilterChainProxyTests.java @@ -10,6 +10,7 @@ import org.mockito.stubbing.Answer; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.firewall.FirewalledRequest; +import org.springframework.security.firewall.HttpFirewall; import javax.servlet.Filter; import javax.servlet.FilterChain; @@ -19,6 +20,7 @@ import java.util.*; /** * @author Luke Taylor + * @author Rob Winch */ @SuppressWarnings({"unchecked"}) public class FilterChainProxyTests { @@ -100,4 +102,29 @@ public class FilterChainProxyTests { verify(chain).doFilter(any(FirewalledRequest.class), any(HttpServletResponse.class)); } + // SEC-1639 + @Test + public void bothWrappersAreResetWithNestedFcps() throws Exception { + HttpFirewall fw = mock(HttpFirewall.class); + FilterChainProxy firstFcp = new FilterChainProxy(); + LinkedHashMap fcm = new LinkedHashMap(); + fcm.put("/match", Arrays.asList(fcp)); + firstFcp.setFilterChainMap(fcm); + firstFcp.setFirewall(fw); + fcp.setFirewall(fw); + FirewalledRequest firstFwr = mock(FirewalledRequest.class, "firstFwr"); + when(firstFwr.getRequestURI()).thenReturn("/match"); + when(firstFwr.getContextPath()).thenReturn(""); + FirewalledRequest fwr = mock(FirewalledRequest.class, "fwr"); + when(fwr.getRequestURI()).thenReturn("/match"); + when(fwr.getContextPath()).thenReturn(""); + when(fw.getFirewalledResponse(any(HttpServletResponse.class))).thenReturn(response); + when(fw.getFirewalledRequest(request)).thenReturn(firstFwr); + when(fw.getFirewalledRequest(firstFwr)).thenReturn(fwr); + when(fwr.getRequest()).thenReturn(firstFwr); + when(firstFwr.getRequest()).thenReturn(request); + firstFcp.doFilter(request, response, chain); + verify(firstFwr).reset(); + verify(fwr).reset(); + } }