Add Request-level CSRF Skip

Fixes gh-7367
This commit is contained in:
Josh Cummings 2019-09-05 05:22:35 -06:00
parent 9920cb41d1
commit aa12748c9b
4 changed files with 92 additions and 12 deletions

View File

@ -35,6 +35,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import static java.lang.Boolean.TRUE;
/**
* <p>
* Applies
@ -63,6 +65,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
*/
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
/**
* The attribute name to use when marking a given request as one that should not be filtered.
*
* To use, set the attribute on your {@link HttpServletRequest}:
* <pre>
* CsrfFilter.skipRequest(request);
* </pre>
*/
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRepository tokenRepository;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
@ -73,6 +85,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.tokenRepository = csrfTokenRepository;
}
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
}
/*
* (non-Javadoc)
*
@ -124,6 +141,10 @@ public final class CsrfFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
}
public static void skipRequest(HttpServletRequest request) {
request.setAttribute(SHOULD_NOT_FILTER, TRUE);
}
/**
* Specifies a {@link RequestMatcher} that is used to determine if CSRF protection
* should be applied. If the {@link RequestMatcher} returns true for a given request,

View File

@ -32,6 +32,8 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import static java.lang.Boolean.TRUE;
/**
* <p>
* Applies
@ -60,6 +62,16 @@ import org.springframework.web.server.WebFilterChain;
public class CsrfWebFilter implements WebFilter {
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
/**
* The attribute name to use when marking a given request as one that should not be filtered.
*
* To use, set the attribute on your {@link ServerWebExchange}:
* <pre>
* CsrfWebFilter.skipExchange(exchange);
* </pre>
*/
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfWebFilter.class.getName();
private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
@ -86,6 +98,10 @@ public class CsrfWebFilter implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
return chain.filter(exchange).then(Mono.empty());
}
return this.requireCsrfProtectionMatcher.matches(exchange)
.filter( matchResult -> matchResult.isMatch())
.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
@ -96,6 +112,10 @@ public class CsrfWebFilter implements WebFilter {
.handle(exchange, e));
}
public static void skipExchange(ServerWebExchange exchange) {
exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
}
private Mono<Void> validateToken(ServerWebExchange exchange) {
return this.csrfTokenRepository.loadToken(exchange)
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("CSRF Token has been associated to this client"))))

View File

@ -31,6 +31,7 @@ import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.access.AccessDeniedHandler;
@ -39,6 +40,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -390,6 +393,22 @@ public class CsrfFilterTests {
verifyZeroInteractions(this.filterChain);
}
@Test
public void doFilterWhenSkipRequestInvokedThenSkips()
throws Exception {
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
CsrfFilter filter = new CsrfFilter(repository);
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
MockHttpServletRequest request = new MockHttpServletRequest();
CsrfFilter.skipRequest(request);
filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
verifyZeroInteractions(repository);
}
@Test(expected = IllegalArgumentException.class)
public void setRequireCsrfProtectionMatcherNull() {
this.filter.setRequireCsrfProtectionMatcher(null);

View File

@ -20,19 +20,24 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.mock.web.server.MockServerWebExchange.from;
/**
* @author Rob Winch
@ -49,10 +54,10 @@ public class CsrfWebFilterTests {
private CsrfWebFilter csrfFilter = new CsrfWebFilter();
private MockServerWebExchange get = MockServerWebExchange.from(
private MockServerWebExchange get = from(
MockServerHttpRequest.get("/"));
private MockServerWebExchange post = MockServerWebExchange.from(
private MockServerWebExchange post = from(
MockServerHttpRequest.post("/"));
@Test
@ -104,7 +109,7 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
this.post = from(MockServerHttpRequest.post("/")
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@ -125,7 +130,7 @@ public class CsrfWebFilterTests {
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
this.post = from(MockServerHttpRequest.post("/")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body(this.token.getParameterName() + "="+this.token.getToken()));
@ -142,7 +147,7 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
this.post = from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@ -163,7 +168,7 @@ public class CsrfWebFilterTests {
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
this.post = from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@ -173,4 +178,19 @@ public class CsrfWebFilterTests {
chainResult.assertWasSubscribed();
}
@Test
public void doFilterWhenSkipExchangeInvokedThenSkips() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
when(this.chain.filter(any())).thenReturn(chainResult.mono());
ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class);
this.csrfFilter.setRequireCsrfProtectionMatcher(matcher);
MockServerWebExchange exchange = from(MockServerHttpRequest.post("/post").build());
CsrfWebFilter.skipExchange(exchange);
this.csrfFilter.filter(exchange, this.chain).block();
verifyZeroInteractions(matcher);
}
}