diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index e4df6edbe9..34b7497ae7 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -125,6 +125,8 @@ public class FilterChainProxy extends GenericFilterBean { //~ Instance fields ================================================================================================ + private final static String FILTER_APPLIED = FilterChainProxy.class.getName().concat(".APPLIED"); + private List filterChains; private FilterChainValidator filterChainValidator = new NullFilterChainValidator(); @@ -151,11 +153,17 @@ public class FilterChainProxy extends GenericFilterBean { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - try { + boolean clearContext = request.getAttribute(FILTER_APPLIED) == null; + if(clearContext) { + try { + request.setAttribute(FILTER_APPLIED, Boolean.TRUE); + doFilterInternal(request, response, chain); + } finally { + SecurityContextHolder.clearContext(); + request.removeAttribute(FILTER_APPLIED); + } + } else { doFilterInternal(request, response, chain); - } finally { - // SEC-1950 - SecurityContextHolder.clearContext(); } } diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 78f9b5f500..71f1d91ab5 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -196,4 +196,31 @@ public class FilterChainProxyTests { assertNull(SecurityContextHolder.getContext().getAuthentication()); } + + // SEC-2027 + @Test + public void doFilterClearsSecurityContextHolderOnceOnForwards() throws Exception { + final FilterChain innerChain = mock(FilterChain.class); + when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true); + doAnswer(new Answer() { + public Object answer(InvocationOnMock inv) throws Throwable { + TestingAuthenticationToken expected = new TestingAuthenticationToken("username", "password"); + SecurityContextHolder.getContext().setAuthentication(expected); + doAnswer(new Answer() { + public Object answer(InvocationOnMock inv) throws Throwable { + innerChain.doFilter(request, response); + return null; + } + }).when(filter).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));; + fcp.doFilter(request, response, innerChain); + assertSame(expected, SecurityContextHolder.getContext().getAuthentication()); + return null; + } + }).when(filter).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + + fcp.doFilter(request, response, chain); + + verify(innerChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + assertNull(SecurityContextHolder.getContext().getAuthentication()); + } }