diff --git a/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java b/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java index 198bfb59b0..9d5360663c 100644 --- a/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java +++ b/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java @@ -21,6 +21,8 @@ import java.util.function.Supplier; import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -28,7 +30,7 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; -import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.filter.GenericFilterBean; /** * A {@link javax.servlet.Filter} that uses the {@link SecurityContextRepository} to @@ -40,17 +42,18 @@ import org.springframework.web.filter.OncePerRequestFilter; * mechanisms to choose individually if authentication should be persisted. * * @author Rob Winch + * @author Marcus da Coregio * @since 5.7 */ -public class SecurityContextHolderFilter extends OncePerRequestFilter { +public class SecurityContextHolderFilter extends GenericFilterBean { + + private static final String FILTER_APPLIED = SecurityContextHolderFilter.class.getName() + ".APPLIED"; private final SecurityContextRepository securityContextRepository; private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder .getContextHolderStrategy(); - private boolean shouldNotFilterErrorDispatch; - /** * Creates a new instance. * @param securityContextRepository the repository to use. Cannot be null. @@ -61,23 +64,29 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter { } @Override - protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); + } + + private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException { + if (request.getAttribute(FILTER_APPLIED) != null) { + chain.doFilter(request, response); + return; + } + request.setAttribute(FILTER_APPLIED, Boolean.TRUE); Supplier deferredContext = this.securityContextRepository.loadDeferredContext(request); try { this.securityContextHolderStrategy.setDeferredContext(deferredContext); - filterChain.doFilter(request, response); + chain.doFilter(request, response); } finally { this.securityContextHolderStrategy.clearContext(); + request.removeAttribute(FILTER_APPLIED); } } - @Override - protected boolean shouldNotFilterErrorDispatch() { - return this.shouldNotFilterErrorDispatch; - } - /** * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. @@ -89,13 +98,4 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter { this.securityContextHolderStrategy = securityContextHolderStrategy; } - /** - * Disables {@link SecurityContextHolderFilter} for error dispatch. - * @param shouldNotFilterErrorDispatch if the Filter should be disabled for error - * dispatch. Default is false. - */ - public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) { - this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch; - } - } diff --git a/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java b/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java index 6b70ad00e0..96c7de421f 100644 --- a/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java @@ -18,6 +18,7 @@ package org.springframework.security.web.context; import java.util.function.Supplier; +import javax.servlet.DispatcherType; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -26,11 +27,15 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.mock.web.MockFilterChain; import org.springframework.security.authentication.TestAuthentication; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; @@ -40,11 +45,17 @@ import org.springframework.security.core.context.SecurityContextImpl; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; @ExtendWith(MockitoExtension.class) class SecurityContextHolderFilterTests { + private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED"; + @Mock private SecurityContextRepository repository; @@ -105,14 +116,38 @@ class SecurityContextHolderFilterTests { } @Test - void shouldNotFilterErrorDispatchWhenDefault() { - assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse(); + void doFilterWhenFilterAppliedThenDoNothing() throws Exception { + given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true); + this.filter.doFilter(this.request, this.response, new MockFilterChain()); + verify(this.request, times(1)).getAttribute(FILTER_APPLIED); + verifyNoInteractions(this.repository, this.response); } @Test - void shouldNotFilterErrorDispatchWhenOverridden() { - this.filter.setShouldNotFilterErrorDispatch(true); - assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue(); + void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception { + given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn( + new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy)); + + this.filter.doFilter(this.request, this.response, new MockFilterChain()); + + InOrder inOrder = inOrder(this.request, this.repository); + inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true); + inOrder.verify(this.repository).loadDeferredContext(this.request); + inOrder.verify(this.request).removeAttribute(FILTER_APPLIED); + } + + @ParameterizedTest + @EnumSource(DispatcherType.class) + void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception { + lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType); + Authentication authentication = TestAuthentication.authenticatedUser(); + SecurityContext expectedContext = new SecurityContextImpl(authentication); + given(this.repository.loadDeferredContext(this.requestArg.capture())) + .willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy)); + FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext()) + .isEqualTo(expectedContext); + + this.filter.doFilter(this.request, this.response, filterChain); } }