diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
index 9ef84dcef2..9a489c6f96 100644
--- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
+++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
@@ -35,6 +35,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
+import static java.lang.Boolean.TRUE;
+
/**
*
* Applies
@@ -63,6 +65,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
*/
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
+ /**
+ * The attribute name to use when marking a given request as one that should not be filtered.
+ *
+ * To use, set the attribute on your {@link HttpServletRequest}:
+ *
+ * CsrfFilter.skipRequest(request);
+ *
+ */
+ private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
+
private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRepository tokenRepository;
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
@@ -73,6 +85,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.tokenRepository = csrfTokenRepository;
}
+ @Override
+ protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
+ return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
+ }
+
/*
* (non-Javadoc)
*
@@ -124,6 +141,10 @@ public final class CsrfFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
}
+ public static void skipRequest(HttpServletRequest request) {
+ request.setAttribute(SHOULD_NOT_FILTER, TRUE);
+ }
+
/**
* Specifies a {@link RequestMatcher} that is used to determine if CSRF protection
* should be applied. If the {@link RequestMatcher} returns true for a given request,
diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
index a74fc3384d..111306c3fd 100644
--- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
+++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
@@ -32,6 +32,8 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
+import static java.lang.Boolean.TRUE;
+
/**
*
* Applies
@@ -60,6 +62,16 @@ import org.springframework.web.server.WebFilterChain;
public class CsrfWebFilter implements WebFilter {
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
+ /**
+ * The attribute name to use when marking a given request as one that should not be filtered.
+ *
+ * To use, set the attribute on your {@link ServerWebExchange}:
+ *
+ * CsrfWebFilter.skipExchange(exchange);
+ *
+ */
+ private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfWebFilter.class.getName();
+
private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
@@ -86,6 +98,10 @@ public class CsrfWebFilter implements WebFilter {
@Override
public Mono filter(ServerWebExchange exchange, WebFilterChain chain) {
+ if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
+ return chain.filter(exchange).then(Mono.empty());
+ }
+
return this.requireCsrfProtectionMatcher.matches(exchange)
.filter( matchResult -> matchResult.isMatch())
.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
@@ -96,6 +112,10 @@ public class CsrfWebFilter implements WebFilter {
.handle(exchange, e));
}
+ public static void skipExchange(ServerWebExchange exchange) {
+ exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
+ }
+
private Mono validateToken(ServerWebExchange exchange) {
return this.csrfTokenRepository.loadToken(exchange)
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("CSRF Token has been associated to this client"))))
diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
index a8b18f52f9..4129201cb4 100644
--- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
+++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
@@ -31,6 +31,7 @@ import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.access.AccessDeniedHandler;
@@ -39,6 +40,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.lenient;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -390,6 +393,22 @@ public class CsrfFilterTests {
verifyZeroInteractions(this.filterChain);
}
+ @Test
+ public void doFilterWhenSkipRequestInvokedThenSkips()
+ throws Exception {
+
+ CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
+ CsrfFilter filter = new CsrfFilter(repository);
+
+ lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
+
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ CsrfFilter.skipRequest(request);
+ filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
+
+ verifyZeroInteractions(repository);
+ }
+
@Test(expected = IllegalArgumentException.class)
public void setRequireCsrfProtectionMatcherNull() {
this.filter.setRequireCsrfProtectionMatcher(null);
diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
index 52759f507b..1101ddbba9 100644
--- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
+++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
@@ -20,19 +20,24 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
-import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
-import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
-import org.springframework.mock.web.server.MockServerWebExchange;
-import org.springframework.web.server.WebFilterChain;
-import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.publisher.PublisherProbe;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+import org.springframework.web.server.WebFilterChain;
+import org.springframework.web.server.WebSession;
+
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
+import static org.springframework.mock.web.server.MockServerWebExchange.from;
/**
* @author Rob Winch
@@ -49,10 +54,10 @@ public class CsrfWebFilterTests {
private CsrfWebFilter csrfFilter = new CsrfWebFilter();
- private MockServerWebExchange get = MockServerWebExchange.from(
+ private MockServerWebExchange get = from(
MockServerHttpRequest.get("/"));
- private MockServerWebExchange post = MockServerWebExchange.from(
+ private MockServerWebExchange post = from(
MockServerHttpRequest.post("/"));
@Test
@@ -104,7 +109,7 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
- this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+ this.post = from(MockServerHttpRequest.post("/")
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
Mono result = this.csrfFilter.filter(this.post, this.chain);
@@ -125,7 +130,7 @@ public class CsrfWebFilterTests {
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
- this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+ this.post = from(MockServerHttpRequest.post("/")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body(this.token.getParameterName() + "="+this.token.getToken()));
@@ -142,7 +147,7 @@ public class CsrfWebFilterTests {
this.csrfFilter.setCsrfTokenRepository(this.repository);
when(this.repository.loadToken(any()))
.thenReturn(Mono.just(this.token));
- this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+ this.post = from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
Mono result = this.csrfFilter.filter(this.post, this.chain);
@@ -163,7 +168,7 @@ public class CsrfWebFilterTests {
.thenReturn(Mono.just(this.token));
when(this.repository.generateToken(any()))
.thenReturn(Mono.just(this.token));
- this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+ this.post = from(MockServerHttpRequest.post("/")
.header(this.token.getHeaderName(), this.token.getToken()));
Mono result = this.csrfFilter.filter(this.post, this.chain);
@@ -173,4 +178,19 @@ public class CsrfWebFilterTests {
chainResult.assertWasSubscribed();
}
+
+ @Test
+ public void doFilterWhenSkipExchangeInvokedThenSkips() {
+ PublisherProbe chainResult = PublisherProbe.empty();
+ when(this.chain.filter(any())).thenReturn(chainResult.mono());
+
+ ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class);
+ this.csrfFilter.setRequireCsrfProtectionMatcher(matcher);
+
+ MockServerWebExchange exchange = from(MockServerHttpRequest.post("/post").build());
+ CsrfWebFilter.skipExchange(exchange);
+ this.csrfFilter.filter(exchange, this.chain).block();
+
+ verifyZeroInteractions(matcher);
+ }
}