Apply SecurityContextHolderFilter to all dispatcher types

Closes gh-11962
This commit is contained in:
Marcus Da Coregio 2022-11-30 03:28:23 -08:00 committed by Marcus Hert Da Coregio
parent 88d50a531b
commit 99d6d21554
2 changed files with 60 additions and 25 deletions

View File

@ -21,6 +21,8 @@ import java.util.function.Supplier;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -28,7 +30,7 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.filter.GenericFilterBean;
/**
* A {@link javax.servlet.Filter} that uses the {@link SecurityContextRepository} to
@ -40,17 +42,18 @@ import org.springframework.web.filter.OncePerRequestFilter;
* mechanisms to choose individually if authentication should be persisted.
*
* @author Rob Winch
* @author Marcus da Coregio
* @since 5.7
*/
public class SecurityContextHolderFilter extends OncePerRequestFilter {
public class SecurityContextHolderFilter extends GenericFilterBean {
private static final String FILTER_APPLIED = SecurityContextHolderFilter.class.getName() + ".APPLIED";
private final SecurityContextRepository securityContextRepository;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private boolean shouldNotFilterErrorDispatch;
/**
* Creates a new instance.
* @param securityContextRepository the repository to use. Cannot be null.
@ -61,23 +64,29 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
if (request.getAttribute(FILTER_APPLIED) != null) {
chain.doFilter(request, response);
return;
}
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
Supplier<SecurityContext> deferredContext = this.securityContextRepository.loadDeferredContext(request);
try {
this.securityContextHolderStrategy.setDeferredContext(deferredContext);
filterChain.doFilter(request, response);
chain.doFilter(request, response);
}
finally {
this.securityContextHolderStrategy.clearContext();
request.removeAttribute(FILTER_APPLIED);
}
}
@Override
protected boolean shouldNotFilterErrorDispatch() {
return this.shouldNotFilterErrorDispatch;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
@ -89,13 +98,4 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/**
* Disables {@link SecurityContextHolderFilter} for error dispatch.
* @param shouldNotFilterErrorDispatch if the Filter should be disabled for error
* dispatch. Default is false.
*/
public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) {
this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch;
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.context;
import java.util.function.Supplier;
import javax.servlet.DispatcherType;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -26,11 +27,15 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
@ -40,11 +45,17 @@ import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
@ExtendWith(MockitoExtension.class)
class SecurityContextHolderFilterTests {
private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED";
@Mock
private SecurityContextRepository repository;
@ -105,14 +116,38 @@ class SecurityContextHolderFilterTests {
}
@Test
void shouldNotFilterErrorDispatchWhenDefault() {
assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();
void doFilterWhenFilterAppliedThenDoNothing() throws Exception {
given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true);
this.filter.doFilter(this.request, this.response, new MockFilterChain());
verify(this.request, times(1)).getAttribute(FILTER_APPLIED);
verifyNoInteractions(this.repository, this.response);
}
@Test
void shouldNotFilterErrorDispatchWhenOverridden() {
this.filter.setShouldNotFilterErrorDispatch(true);
assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue();
void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception {
given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn(
new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy));
this.filter.doFilter(this.request, this.response, new MockFilterChain());
InOrder inOrder = inOrder(this.request, this.repository);
inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true);
inOrder.verify(this.repository).loadDeferredContext(this.request);
inOrder.verify(this.request).removeAttribute(FILTER_APPLIED);
}
@ParameterizedTest
@EnumSource(DispatcherType.class)
void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception {
lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType);
Authentication authentication = TestAuthentication.authenticatedUser();
SecurityContext expectedContext = new SecurityContextImpl(authentication);
given(this.repository.loadDeferredContext(this.requestArg.capture()))
.willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy));
FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
.isEqualTo(expectedContext);
this.filter.doFilter(this.request, this.response, filterChain);
}
}