diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java
index ccfe5fd2bb..cc7d435f3d 100644
--- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java
+++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java
@@ -48,6 +48,7 @@ import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
+import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler;
import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -62,6 +63,7 @@ import org.springframework.web.servlet.support.RequestDataValueProcessor;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.not;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given;
@@ -423,7 +425,7 @@ public class CsrfConfigurerTests {
}
@Test
- public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
+ public void getLoginWhenCsrfTokenRequestAttributeHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
@@ -438,7 +440,7 @@ public class CsrfConfigurerTests {
}
@Test
- public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
+ public void loginWhenCsrfTokenRequestAttributeHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
@@ -460,6 +462,47 @@ public class CsrfConfigurerTests {
verifyNoMoreInteractions(csrfTokenRepository);
}
+ @Test
+ public void getLoginWhenXorCsrfTokenRequestAttributeHandlerSetThenRespondsWithMaskedCsrfToken() throws Exception {
+ CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
+ CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
+ given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
+ .willReturn(new TestDeferredCsrfToken(csrfToken));
+ CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
+ CsrfTokenRequestHandlerConfig.HANDLER = new XorCsrfTokenRequestAttributeHandler();
+ this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
+ this.mvc.perform(get("/login")).andExpect(status().isOk())
+ .andExpect(content().string(not(containsString(csrfToken.getToken()))));
+ verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
+ verifyNoMoreInteractions(csrfTokenRepository);
+ }
+
+ @Test
+ public void loginWhenXorCsrfTokenRequestAttributeHandlerSetAndMaskedCsrfTokenThenSuccess() throws Exception {
+ CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
+ CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
+ given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
+ .willReturn(new TestDeferredCsrfToken(csrfToken));
+ CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
+ CsrfTokenRequestHandlerConfig.HANDLER = new XorCsrfTokenRequestAttributeHandler();
+ this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
+
+ MvcResult mvcResult = this.mvc.perform(get("/login")).andReturn();
+ CsrfToken csrfTokenAttribute = (CsrfToken) mvcResult.getRequest().getAttribute(CsrfToken.class.getName());
+
+ // @formatter:off
+ MockHttpServletRequestBuilder loginRequest = post("/login")
+ .header(csrfToken.getHeaderName(), csrfTokenAttribute.getToken())
+ .param("username", "user")
+ .param("password", "password");
+ // @formatter:on
+ this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
+ verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
+ verify(csrfTokenRepository, times(3)).loadDeferredToken(any(HttpServletRequest.class),
+ any(HttpServletResponse.class));
+ verifyNoMoreInteractions(csrfTokenRepository);
+ }
+
@Configuration
static class AllowHttpMethodsFirewallConfig {
diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java
index e9220895fb..4da7fbc565 100644
--- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java
+++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java
@@ -300,6 +300,39 @@ public class CsrfConfigTests {
// @formatter:on
}
+ @Test
+ public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception {
+ this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
+ .autowire();
+ // @formatter:off
+ MvcResult mvcResult = this.mvc.perform(get("/ok"))
+ .andExpect(status().isOk())
+ .andReturn();
+ MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
+ CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf");
+ MockHttpServletRequestBuilder ok = post("/ok")
+ .header(csrfToken.getHeaderName(), csrfToken.getToken())
+ .session(session);
+ this.mvc.perform(ok).andExpect(status().isOk());
+ // @formatter:on
+ }
+
+ @Test
+ public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception {
+ this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
+ .autowire();
+ // @formatter:off
+ MvcResult mvcResult = this.mvc.perform(get("/ok"))
+ .andExpect(status().isOk())
+ .andReturn();
+ MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
+ MockHttpServletRequestBuilder ok = post("/ok")
+ .with(csrf())
+ .session(session);
+ this.mvc.perform(ok).andExpect(status().isForbidden());
+ // @formatter:on
+ }
+
@Test
public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication()
throws Exception {
diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithXorCsrfTokenRequestAttributeHandler.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithXorCsrfTokenRequestAttributeHandler.xml
new file mode 100644
index 0000000000..beee0c88c7
--- /dev/null
+++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithXorCsrfTokenRequestAttributeHandler.xml
@@ -0,0 +1,32 @@
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc b/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc
index 39447e3ea4..39a56ec3c0 100644
--- a/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc
+++ b/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc
@@ -163,13 +163,76 @@ class SecurityConfig {
----
====
+[[servlet-csrf-configure-request-handler]]
+==== Configure CsrfTokenRequestHandler
+
+Spring Security's https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfFilter.html[CsrfFilter] exposes a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfToken.html[CsrfToken] as an `HttpServletRequest` attribute named `_csrf` with the help of a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfTokenRequestHandler.html[CsrfTokenRequestHandler].
+The default implementation is `CsrfTokenRequestAttributeHandler`.
+
+An alternate implementation `XorCsrfTokenRequestAttributeHandler` is available to provide protection for BREACH (see https://github.com/spring-projects/spring-security/issues/4001[gh-4001]).
+
+You can configure `XorCsrfTokenRequestAttributeHandler` in XML using the following:
+
+.Configure BREACH protection XML Configuration
+====
+[source,xml]
+----
+
+
+
+
+
+----
+====
+
+You can configure `XorCsrfTokenRequestAttributeHandler` in Java Configuration using the following:
+
+.Configure BREACH protection
+====
+.Java
+[source,java,role="primary"]
+----
+@EnableWebSecurity
+public class WebSecurityConfig {
+
+ @Bean
+ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
+ http
+ .csrf(csrf -> csrf
+ .csrfTokenRequestHandler(new XorCsrfTokenRequestAttributeHandler())
+ );
+ return http.build();
+ }
+}
+----
+
+.Kotlin
+[source,kotlin,role="secondary"]
+----
+@EnableWebSecurity
+class SecurityConfig {
+
+ @Bean
+ open fun filterChain(http: HttpSecurity): SecurityFilterChain {
+ http {
+ csrf {
+ csrfTokenRequestHandler = XorCsrfTokenRequestAttributeHandler()
+ }
+ }
+ return http.build()
+ }
+}
+----
+====
+
[[servlet-csrf-include]]
=== Include the CSRF Token
In order for the xref:features/exploits/csrf.adoc#csrf-protection-stp[synchronizer token pattern] to protect against CSRF attacks, we must include the actual CSRF token in the HTTP request.
This must be included in a part of the request (i.e. form parameter, HTTP header, etc) that is not automatically included in the HTTP request by the browser.
-Spring Security's https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfFilter.html[CsrfFilter] exposes a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfToken.html[CsrfToken] as an `HttpServletRequest` attribute named `_csrf`.
+<> that the `CsrfToken` is exposed as a request attribute.
This means that any view technology can access the `CsrfToken` to expose the expected token as either a <> or <>.
Fortunately, there are integrations listed below that make including the token in <> and <> requests even easier.
diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java
new file mode 100644
index 0000000000..160b51d19e
--- /dev/null
+++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.csrf;
+
+import java.security.SecureRandom;
+import java.util.Base64;
+import java.util.function.Supplier;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import org.springframework.security.crypto.codec.Utf8;
+import org.springframework.util.Assert;
+
+/**
+ * An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of
+ * masking the value of the {@link CsrfToken} on each request and resolving the raw token
+ * value from the masked value as either a header or parameter value of the request.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ */
+public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestAttributeHandler {
+
+ private SecureRandom secureRandom = new SecureRandom();
+
+ /**
+ * Specifies the {@code SecureRandom} used to generate random bytes that are used to
+ * mask the value of the {@link CsrfToken} on each request.
+ * @param secureRandom the {@code SecureRandom} to use to generate random bytes
+ */
+ public void setSecureRandom(SecureRandom secureRandom) {
+ Assert.notNull(secureRandom, "secureRandom cannot be null");
+ this.secureRandom = secureRandom;
+ }
+
+ @Override
+ public void handle(HttpServletRequest request, HttpServletResponse response,
+ Supplier deferredCsrfToken) {
+ Assert.notNull(request, "request cannot be null");
+ Assert.notNull(response, "response cannot be null");
+ Assert.notNull(deferredCsrfToken, "deferredCsrfToken cannot be null");
+ Supplier updatedCsrfToken = deferCsrfTokenUpdate(deferredCsrfToken);
+ super.handle(request, response, updatedCsrfToken);
+ }
+
+ private Supplier deferCsrfTokenUpdate(Supplier csrfTokenSupplier) {
+ return () -> {
+ CsrfToken csrfToken = csrfTokenSupplier.get();
+ Assert.state(csrfToken != null, "csrfToken supplier returned null");
+ String updatedToken = createXoredCsrfToken(this.secureRandom, csrfToken.getToken());
+ return new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), updatedToken);
+ };
+ }
+
+ @Override
+ public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
+ String actualToken = super.resolveCsrfTokenValue(request, csrfToken);
+ return getTokenValue(actualToken, csrfToken.getToken());
+ }
+
+ private static String getTokenValue(String actualToken, String token) {
+ byte[] actualBytes;
+ try {
+ actualBytes = Base64.getUrlDecoder().decode(actualToken);
+ }
+ catch (Exception ex) {
+ return null;
+ }
+
+ byte[] tokenBytes = Utf8.encode(token);
+ int tokenSize = tokenBytes.length;
+ if (actualBytes.length < tokenSize) {
+ return null;
+ }
+
+ // extract token and random bytes
+ int randomBytesSize = actualBytes.length - tokenSize;
+ byte[] xoredCsrf = new byte[tokenSize];
+ byte[] randomBytes = new byte[randomBytesSize];
+
+ System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
+ System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
+
+ byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
+ return Utf8.decode(csrfBytes);
+ }
+
+ private static String createXoredCsrfToken(SecureRandom secureRandom, String token) {
+ byte[] tokenBytes = Utf8.encode(token);
+ byte[] randomBytes = new byte[tokenBytes.length];
+ secureRandom.nextBytes(randomBytes);
+
+ byte[] xoredBytes = xorCsrf(randomBytes, tokenBytes);
+ byte[] combinedBytes = new byte[tokenBytes.length + randomBytes.length];
+ System.arraycopy(randomBytes, 0, combinedBytes, 0, randomBytes.length);
+ System.arraycopy(xoredBytes, 0, combinedBytes, randomBytes.length, xoredBytes.length);
+
+ return Base64.getUrlEncoder().encodeToString(combinedBytes);
+ }
+
+ private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
+ int len = Math.min(randomBytes.length, csrfBytes.length);
+ byte[] xoredCsrf = new byte[len];
+ System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
+ for (int i = 0; i < len; i++) {
+ xoredCsrf[i] ^= randomBytes[i];
+ }
+ return xoredCsrf;
+ }
+
+}
diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
index 8b0fc1449a..4ad810329a 100644
--- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
+++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
@@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
import java.io.IOException;
import java.util.Arrays;
+import java.util.Base64;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
@@ -33,6 +34,8 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.access.AccessDeniedException;
+import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -362,6 +365,45 @@ public class CsrfFilterTests {
verify(this.filterChain).doFilter(this.request, this.response);
}
+ @Test
+ public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
+ given(this.requestMatcher.matches(this.request)).willReturn(false);
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
+ XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
+ requestHandler.setCsrfRequestAttributeName(this.token.getParameterName());
+ this.filter.setRequestHandler(requestHandler);
+ this.filter.doFilter(this.request, this.response, this.filterChain);
+ assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
+ assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
+ verify(this.filterChain).doFilter(this.request, this.response);
+ assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
+
+ CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
+ byte[] csrfTokenAttributeBytes = Base64.getUrlDecoder().decode(csrfTokenAttribute.getToken());
+ byte[] actualTokenBytes = Utf8.encode(this.token.getToken());
+ // XOR'd token length is 2x due to containing the random bytes
+ assertThat(csrfTokenAttributeBytes).hasSize(actualTokenBytes.length * 2);
+
+ given(this.requestMatcher.matches(this.request)).willReturn(true);
+ this.request.setParameter(this.token.getParameterName(), csrfTokenAttribute.getToken());
+ this.filter.doFilter(this.request, this.response, this.filterChain);
+ verify(this.filterChain, times(2)).doFilter(this.request, this.response);
+ }
+
+ @Test
+ public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
+ given(this.requestMatcher.matches(this.request)).willReturn(true);
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
+ XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
+ this.filter.setRequestHandler(requestHandler);
+ this.request.setParameter(this.token.getParameterName(), this.token.getToken());
+ this.filter.doFilter(this.request, this.response, this.filterChain);
+ verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
+ verifyNoMoreInteractions(this.filterChain);
+ }
+
@Test
public void setRequireCsrfProtectionMatcherNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));
diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java
new file mode 100644
index 0000000000..8b03382451
--- /dev/null
+++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java
@@ -0,0 +1,203 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.csrf;
+
+import java.security.SecureRandom;
+import java.util.Arrays;
+import java.util.Base64;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.stubbing.Answer;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.willAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+/**
+ * Tests for {@link XorCsrfTokenRequestAttributeHandler}.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ */
+public class XorCsrfTokenRequestAttributeHandlerTests {
+
+ private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 };
+
+ private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES);
+
+ private MockHttpServletRequest request;
+
+ private MockHttpServletResponse response;
+
+ private CsrfToken token;
+
+ private SecureRandom secureRandom;
+
+ private XorCsrfTokenRequestAttributeHandler handler;
+
+ @BeforeEach
+ public void setup() {
+ this.request = new MockHttpServletRequest();
+ this.response = new MockHttpServletResponse();
+ this.token = new DefaultCsrfToken("headerName", "paramName", "abc");
+ this.secureRandom = mock(SecureRandom.class);
+ this.handler = new XorCsrfTokenRequestAttributeHandler();
+ }
+
+ @Test
+ public void setSecureRandomWhenNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.setSecureRandom(null))
+ .withMessage("secureRandom cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.handle(null, this.response, () -> this.token))
+ .withMessage("request cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.handle(this.request, null, () -> this.token))
+ .withMessage("response cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() {
+ // @formatter:off
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.handler.handle(this.request, this.response, null))
+ .withMessage("deferredCsrfToken cannot be null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenCsrfTokenIsNullThenThrowsIllegalStateException() {
+ // @formatter:off
+ assertThatIllegalStateException()
+ .isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null))
+ .withMessage("csrfToken supplier returned null");
+ // @formatter:on
+ }
+
+ @Test
+ public void handleWhenCsrfRequestAttributeSetThenUsed() {
+ willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray());
+
+ this.handler.setSecureRandom(this.secureRandom);
+ this.handler.setCsrfRequestAttributeName("_csrf");
+ this.handler.handle(this.request, this.response, () -> this.token);
+ assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
+ assertThat(this.request.getAttribute("_csrf")).isNotNull();
+
+ CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute("_csrf");
+ assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE);
+ }
+
+ @Test
+ public void handleWhenSecureRandomSetThenUsed() {
+ this.handler.setSecureRandom(this.secureRandom);
+ this.handler.handle(this.request, this.response, () -> this.token);
+ verify(this.secureRandom).nextBytes(anyByteArray());
+ verifyNoMoreInteractions(this.secureRandom);
+ }
+
+ @Test
+ public void handleWhenValidParametersThenRequestAttributesSet() {
+ willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray());
+
+ this.handler.setSecureRandom(this.secureRandom);
+ this.handler.handle(this.request, this.response, () -> this.token);
+ verify(this.secureRandom).nextBytes(anyByteArray());
+ assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
+ assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
+
+ CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
+ assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE);
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
+ assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
+ .withMessage("request cannot be null");
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
+ assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
+ .withMessage("csrfToken cannot be null");
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() {
+ String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
+ assertThat(tokenValue).isNull();
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() {
+ this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE);
+ String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
+ assertThat(tokenValue).isEqualTo(this.token.getToken());
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() {
+ this.request.addHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
+ String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
+ assertThat(tokenValue).isEqualTo(this.token.getToken());
+ }
+
+ @Test
+ public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() {
+ this.request.addHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
+ this.request.setParameter(this.token.getParameterName(), "invalid");
+ String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
+ assertThat(tokenValue).isEqualTo(this.token.getToken());
+ }
+
+ private static Answer fillByteArray() {
+ return (invocation) -> {
+ byte[] bytes = invocation.getArgument(0);
+ Arrays.fill(bytes, (byte) 1);
+ return null;
+ };
+ }
+
+ private static byte[] anyByteArray() {
+ return any(byte[].class);
+ }
+
+}