diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 9ef84dcef2..9a489c6f96 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -35,6 +35,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; +import static java.lang.Boolean.TRUE; + /** *

* Applies @@ -63,6 +65,16 @@ public final class CsrfFilter extends OncePerRequestFilter { */ public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher(); + /** + * The attribute name to use when marking a given request as one that should not be filtered. + * + * To use, set the attribute on your {@link HttpServletRequest}: + *

+	 * 	CsrfFilter.skipRequest(request);
+	 * 
+ */ + private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName(); + private final Log logger = LogFactory.getLog(getClass()); private final CsrfTokenRepository tokenRepository; private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; @@ -73,6 +85,11 @@ public final class CsrfFilter extends OncePerRequestFilter { this.tokenRepository = csrfTokenRepository; } + @Override + protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER)); + } + /* * (non-Javadoc) * @@ -124,6 +141,10 @@ public final class CsrfFilter extends OncePerRequestFilter { filterChain.doFilter(request, response); } + public static void skipRequest(HttpServletRequest request) { + request.setAttribute(SHOULD_NOT_FILTER, TRUE); + } + /** * Specifies a {@link RequestMatcher} that is used to determine if CSRF protection * should be applied. If the {@link RequestMatcher} returns true for a given request, diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index a74fc3384d..111306c3fd 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -32,6 +32,8 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import static java.lang.Boolean.TRUE; + /** *

* Applies @@ -60,6 +62,16 @@ import org.springframework.web.server.WebFilterChain; public class CsrfWebFilter implements WebFilter { public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher(); + /** + * The attribute name to use when marking a given request as one that should not be filtered. + * + * To use, set the attribute on your {@link ServerWebExchange}: + *

+	 * 	CsrfWebFilter.skipExchange(exchange);
+	 * 
+ */ + private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfWebFilter.class.getName(); + private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER; private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); @@ -86,6 +98,10 @@ public class CsrfWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) { + return chain.filter(exchange).then(Mono.empty()); + } + return this.requireCsrfProtectionMatcher.matches(exchange) .filter( matchResult -> matchResult.isMatch()) .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName())) @@ -96,6 +112,10 @@ public class CsrfWebFilter implements WebFilter { .handle(exchange, e)); } + public static void skipExchange(ServerWebExchange exchange) { + exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE); + } + private Mono validateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.loadToken(exchange) .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("CSRF Token has been associated to this client")))) diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index a8b18f52f9..4129201cb4 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -31,6 +31,7 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.access.AccessDeniedHandler; @@ -39,6 +40,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -390,6 +393,22 @@ public class CsrfFilterTests { verifyZeroInteractions(this.filterChain); } + @Test + public void doFilterWhenSkipRequestInvokedThenSkips() + throws Exception { + + CsrfTokenRepository repository = mock(CsrfTokenRepository.class); + CsrfFilter filter = new CsrfFilter(repository); + + lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token); + + MockHttpServletRequest request = new MockHttpServletRequest(); + CsrfFilter.skipRequest(request); + filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain()); + + verifyZeroInteractions(repository); + } + @Test(expected = IllegalArgumentException.class) public void setRequireCsrfProtectionMatcherNull() { this.filter.setRequireCsrfProtectionMatcher(null); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index 52759f507b..1101ddbba9 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -20,19 +20,24 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.mock.http.server.reactive.MockServerHttpRequest; -import org.springframework.mock.web.server.MockServerWebExchange; -import org.springframework.web.server.WebFilterChain; -import org.springframework.web.server.WebSession; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import reactor.test.publisher.PublisherProbe; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebSession; + import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; +import static org.springframework.mock.web.server.MockServerWebExchange.from; /** * @author Rob Winch @@ -49,10 +54,10 @@ public class CsrfWebFilterTests { private CsrfWebFilter csrfFilter = new CsrfWebFilter(); - private MockServerWebExchange get = MockServerWebExchange.from( + private MockServerWebExchange get = from( MockServerHttpRequest.get("/")); - private MockServerWebExchange post = MockServerWebExchange.from( + private MockServerWebExchange post = from( MockServerHttpRequest.post("/")); @Test @@ -104,7 +109,7 @@ public class CsrfWebFilterTests { this.csrfFilter.setCsrfTokenRepository(this.repository); when(this.repository.loadToken(any())) .thenReturn(Mono.just(this.token)); - this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + this.post = from(MockServerHttpRequest.post("/") .body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID")); Mono result = this.csrfFilter.filter(this.post, this.chain); @@ -125,7 +130,7 @@ public class CsrfWebFilterTests { .thenReturn(Mono.just(this.token)); when(this.repository.generateToken(any())) .thenReturn(Mono.just(this.token)); - this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + this.post = from(MockServerHttpRequest.post("/") .contentType(MediaType.APPLICATION_FORM_URLENCODED) .body(this.token.getParameterName() + "="+this.token.getToken())); @@ -142,7 +147,7 @@ public class CsrfWebFilterTests { this.csrfFilter.setCsrfTokenRepository(this.repository); when(this.repository.loadToken(any())) .thenReturn(Mono.just(this.token)); - this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + this.post = from(MockServerHttpRequest.post("/") .header(this.token.getHeaderName(), this.token.getToken()+"INVALID")); Mono result = this.csrfFilter.filter(this.post, this.chain); @@ -163,7 +168,7 @@ public class CsrfWebFilterTests { .thenReturn(Mono.just(this.token)); when(this.repository.generateToken(any())) .thenReturn(Mono.just(this.token)); - this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/") + this.post = from(MockServerHttpRequest.post("/") .header(this.token.getHeaderName(), this.token.getToken())); Mono result = this.csrfFilter.filter(this.post, this.chain); @@ -173,4 +178,19 @@ public class CsrfWebFilterTests { chainResult.assertWasSubscribed(); } + + @Test + public void doFilterWhenSkipExchangeInvokedThenSkips() { + PublisherProbe chainResult = PublisherProbe.empty(); + when(this.chain.filter(any())).thenReturn(chainResult.mono()); + + ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class); + this.csrfFilter.setRequireCsrfProtectionMatcher(matcher); + + MockServerWebExchange exchange = from(MockServerHttpRequest.post("/post").build()); + CsrfWebFilter.skipExchange(exchange); + this.csrfFilter.filter(exchange, this.chain).block(); + + verifyZeroInteractions(matcher); + } }