diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 5b2d439228..78d1f770d8 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -40,6 +40,7 @@ import org.springframework.security.web.firewall.HttpFirewall; import org.springframework.security.web.firewall.RequestRejectedException; import org.springframework.security.web.firewall.RequestRejectedHandler; import org.springframework.security.web.firewall.StrictHttpFirewall; +import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -154,6 +155,8 @@ public class FilterChainProxy extends GenericFilterBean { private RequestRejectedHandler requestRejectedHandler = new DefaultRequestRejectedHandler(); + private ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer(); + public FilterChainProxy() { } @@ -182,8 +185,15 @@ public class FilterChainProxy extends GenericFilterBean { request.setAttribute(FILTER_APPLIED, Boolean.TRUE); doFilterInternal(request, response, chain); } - catch (RequestRejectedException ex) { - this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex); + catch (Exception ex) { + Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex); + Throwable requestRejectedException = this.throwableAnalyzer + .getFirstThrowableOfType(RequestRejectedException.class, causeChain); + if (!(requestRejectedException instanceof RequestRejectedException)) { + throw ex; + } + this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, + (RequestRejectedException) requestRejectedException); } finally { SecurityContextHolder.clearContext(); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 59db2f705f..49a0f283b4 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -49,6 +49,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -252,4 +253,18 @@ public class FilterChainProxyTests { verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); } + @Test + public void requestRejectedHandlerIsCalledIfFirewallThrowsWrappedRequestRejectedException() throws Exception { + HttpFirewall fw = mock(HttpFirewall.class); + RequestRejectedHandler rjh = mock(RequestRejectedHandler.class); + this.fcp.setFirewall(fw); + this.fcp.setRequestRejectedHandler(rjh); + RequestRejectedException requestRejectedException = new RequestRejectedException("Contains illegal chars"); + ServletException servletException = new ServletException(requestRejectedException); + given(fw.getFirewalledRequest(this.request)).willReturn(mock(FirewalledRequest.class)); + willThrow(servletException).given(this.chain).doFilter(any(), any()); + this.fcp.doFilter(this.request, this.response, this.chain); + verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); + } + }