diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 06a491439e..f7cd7b14ac 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -31,7 +31,6 @@ import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; -import javax.servlet.ServletRequestWrapper; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -115,7 +114,7 @@ import java.util.*; * @author Carlos Sanchez * @author Ben Alex * @author Luke Taylor - * + * @author Rob Winch */ public class FilterChainProxy extends GenericFilterBean { //~ Static fields/initializers ===================================================================================== @@ -160,7 +159,7 @@ public class FilterChainProxy extends GenericFilterBean { return; } - VirtualFilterChain vfc = new VirtualFilterChain(url, chain, filters); + VirtualFilterChain vfc = new VirtualFilterChain(url, chain, filters, fwRequest); vfc.doFilter(fwRequest, fwResponse); } @@ -287,15 +286,17 @@ public class FilterChainProxy extends GenericFilterBean { private static class VirtualFilterChain implements FilterChain { private final FilterChain originalChain; private final List additionalFilters; + private final FirewalledRequest firewalledRequest; private final String url; private final int size; private int currentPosition = 0; - private VirtualFilterChain(String url, FilterChain chain, List additionalFilters) { + private VirtualFilterChain(String url, FilterChain chain, List additionalFilters, FirewalledRequest firewalledRequest) { this.originalChain = chain; this.url = url; this.additionalFilters = additionalFilters; this.size = additionalFilters.size(); + this.firewalledRequest = firewalledRequest; } public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { @@ -305,7 +306,7 @@ public class FilterChainProxy extends GenericFilterBean { } // Deactivate path stripping as we exit the security filter chain - resetWrapper(request); + this.firewalledRequest.reset(); originalChain.doFilter(request, response); } else { @@ -322,16 +323,6 @@ public class FilterChainProxy extends GenericFilterBean { nextFilter.doFilter(request, response, this); } } - - private void resetWrapper(ServletRequest request) { - while (request instanceof ServletRequestWrapper) { - if (request instanceof FirewalledRequest) { - ((FirewalledRequest)request).reset(); - break; - } - request = ((ServletRequestWrapper)request).getRequest(); - } - } } public interface FilterChainValidator { diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 89e2c8de95..4e2e98e3ec 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -22,6 +22,7 @@ import java.util.*; /** * @author Luke Taylor + * @author Rob Winch */ @SuppressWarnings({"unchecked"}) public class FilterChainProxyTests { @@ -126,4 +127,30 @@ public class FilterChainProxyTests { fcp.doFilter(request, response, chain); verify(fwr).reset(); } -} + + // SEC-1639 + @Test + public void bothWrappersAreResetWithNestedFcps() throws Exception { + HttpFirewall fw = mock(HttpFirewall.class); + FilterChainProxy firstFcp = new FilterChainProxy(); + LinkedHashMap fcm = new LinkedHashMap(); + fcm.put(matcher, Arrays.asList(fcp)); + firstFcp.setFilterChainMap(fcm); + firstFcp.setFirewall(fw); + fcp.setFirewall(fw); + FirewalledRequest firstFwr = mock(FirewalledRequest.class, "firstFwr"); + when(firstFwr.getRequestURI()).thenReturn("/"); + when(firstFwr.getContextPath()).thenReturn(""); + FirewalledRequest fwr = mock(FirewalledRequest.class, "fwr"); + when(fwr.getRequestURI()).thenReturn("/"); + when(fwr.getContextPath()).thenReturn(""); + when(fw.getFirewalledRequest(request)).thenReturn(firstFwr); + when(fw.getFirewalledRequest(firstFwr)).thenReturn(fwr); + when(fwr.getRequest()).thenReturn(firstFwr); + when(firstFwr.getRequest()).thenReturn(request); + when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true); + firstFcp.doFilter(request, response, chain); + verify(firstFwr).reset(); + verify(fwr).reset(); + } +} \ No newline at end of file