Apply SecurityContextHolderFilter to all dispatcher types
Closes gh-11962
This commit is contained in:
parent
88d50a531b
commit
99d6d21554
|
@ -21,6 +21,8 @@ import java.util.function.Supplier;
|
|||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.ServletResponse;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
|
@ -28,7 +30,7 @@ import org.springframework.security.core.context.SecurityContext;
|
|||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
import org.springframework.web.filter.GenericFilterBean;
|
||||
|
||||
/**
|
||||
* A {@link javax.servlet.Filter} that uses the {@link SecurityContextRepository} to
|
||||
|
@ -40,17 +42,18 @@ import org.springframework.web.filter.OncePerRequestFilter;
|
|||
* mechanisms to choose individually if authentication should be persisted.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @author Marcus da Coregio
|
||||
* @since 5.7
|
||||
*/
|
||||
public class SecurityContextHolderFilter extends OncePerRequestFilter {
|
||||
public class SecurityContextHolderFilter extends GenericFilterBean {
|
||||
|
||||
private static final String FILTER_APPLIED = SecurityContextHolderFilter.class.getName() + ".APPLIED";
|
||||
|
||||
private final SecurityContextRepository securityContextRepository;
|
||||
|
||||
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
|
||||
.getContextHolderStrategy();
|
||||
|
||||
private boolean shouldNotFilterErrorDispatch;
|
||||
|
||||
/**
|
||||
* Creates a new instance.
|
||||
* @param securityContextRepository the repository to use. Cannot be null.
|
||||
|
@ -61,23 +64,29 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
|
||||
throws IOException, ServletException {
|
||||
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
|
||||
}
|
||||
|
||||
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
|
||||
throws ServletException, IOException {
|
||||
if (request.getAttribute(FILTER_APPLIED) != null) {
|
||||
chain.doFilter(request, response);
|
||||
return;
|
||||
}
|
||||
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
|
||||
Supplier<SecurityContext> deferredContext = this.securityContextRepository.loadDeferredContext(request);
|
||||
try {
|
||||
this.securityContextHolderStrategy.setDeferredContext(deferredContext);
|
||||
filterChain.doFilter(request, response);
|
||||
chain.doFilter(request, response);
|
||||
}
|
||||
finally {
|
||||
this.securityContextHolderStrategy.clearContext();
|
||||
request.removeAttribute(FILTER_APPLIED);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean shouldNotFilterErrorDispatch() {
|
||||
return this.shouldNotFilterErrorDispatch;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
|
||||
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
|
||||
|
@ -89,13 +98,4 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
|
|||
this.securityContextHolderStrategy = securityContextHolderStrategy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Disables {@link SecurityContextHolderFilter} for error dispatch.
|
||||
* @param shouldNotFilterErrorDispatch if the Filter should be disabled for error
|
||||
* dispatch. Default is false.
|
||||
*/
|
||||
public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) {
|
||||
this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.springframework.security.web.context;
|
|||
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import javax.servlet.DispatcherType;
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
@ -26,11 +27,15 @@ import org.junit.jupiter.api.AfterEach;
|
|||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Captor;
|
||||
import org.mockito.InOrder;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.mock.web.MockFilterChain;
|
||||
import org.springframework.security.authentication.TestAuthentication;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
|
@ -40,11 +45,17 @@ import org.springframework.security.core.context.SecurityContextImpl;
|
|||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.inOrder;
|
||||
import static org.mockito.Mockito.lenient;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class SecurityContextHolderFilterTests {
|
||||
|
||||
private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED";
|
||||
|
||||
@Mock
|
||||
private SecurityContextRepository repository;
|
||||
|
||||
|
@ -105,14 +116,38 @@ class SecurityContextHolderFilterTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
void shouldNotFilterErrorDispatchWhenDefault() {
|
||||
assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();
|
||||
void doFilterWhenFilterAppliedThenDoNothing() throws Exception {
|
||||
given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true);
|
||||
this.filter.doFilter(this.request, this.response, new MockFilterChain());
|
||||
verify(this.request, times(1)).getAttribute(FILTER_APPLIED);
|
||||
verifyNoInteractions(this.repository, this.response);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldNotFilterErrorDispatchWhenOverridden() {
|
||||
this.filter.setShouldNotFilterErrorDispatch(true);
|
||||
assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue();
|
||||
void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception {
|
||||
given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn(
|
||||
new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy));
|
||||
|
||||
this.filter.doFilter(this.request, this.response, new MockFilterChain());
|
||||
|
||||
InOrder inOrder = inOrder(this.request, this.repository);
|
||||
inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true);
|
||||
inOrder.verify(this.repository).loadDeferredContext(this.request);
|
||||
inOrder.verify(this.request).removeAttribute(FILTER_APPLIED);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(DispatcherType.class)
|
||||
void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception {
|
||||
lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType);
|
||||
Authentication authentication = TestAuthentication.authenticatedUser();
|
||||
SecurityContext expectedContext = new SecurityContextImpl(authentication);
|
||||
given(this.repository.loadDeferredContext(this.requestArg.capture()))
|
||||
.willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy));
|
||||
FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
|
||||
.isEqualTo(expectedContext);
|
||||
|
||||
this.filter.doFilter(this.request, this.response, filterChain);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue