diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index ccfe5fd2bb..cc7d435f3d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -48,6 +48,7 @@ import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.DeferredCsrfToken; +import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler; import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -62,6 +63,7 @@ import org.springframework.web.servlet.support.RequestDataValueProcessor; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; @@ -423,7 +425,7 @@ public class CsrfConfigurerTests { } @Test - public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception { + public void getLoginWhenCsrfTokenRequestAttributeHandlerSetThenRespondsWithNormalCsrfToken() throws Exception { CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) @@ -438,7 +440,7 @@ public class CsrfConfigurerTests { } @Test - public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception { + public void loginWhenCsrfTokenRequestAttributeHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) @@ -460,6 +462,47 @@ public class CsrfConfigurerTests { verifyNoMoreInteractions(csrfTokenRepository); } + @Test + public void getLoginWhenXorCsrfTokenRequestAttributeHandlerSetThenRespondsWithMaskedCsrfToken() throws Exception { + CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); + CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); + given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) + .willReturn(new TestDeferredCsrfToken(csrfToken)); + CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + CsrfTokenRequestHandlerConfig.HANDLER = new XorCsrfTokenRequestAttributeHandler(); + this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/login")).andExpect(status().isOk()) + .andExpect(content().string(not(containsString(csrfToken.getToken())))); + verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verifyNoMoreInteractions(csrfTokenRepository); + } + + @Test + public void loginWhenXorCsrfTokenRequestAttributeHandlerSetAndMaskedCsrfTokenThenSuccess() throws Exception { + CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); + CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); + given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) + .willReturn(new TestDeferredCsrfToken(csrfToken)); + CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + CsrfTokenRequestHandlerConfig.HANDLER = new XorCsrfTokenRequestAttributeHandler(); + this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); + + MvcResult mvcResult = this.mvc.perform(get("/login")).andReturn(); + CsrfToken csrfTokenAttribute = (CsrfToken) mvcResult.getRequest().getAttribute(CsrfToken.class.getName()); + + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .header(csrfToken.getHeaderName(), csrfTokenAttribute.getToken()) + .param("username", "user") + .param("password", "password"); + // @formatter:on + this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(csrfTokenRepository, times(3)).loadDeferredToken(any(HttpServletRequest.class), + any(HttpServletResponse.class)); + verifyNoMoreInteractions(csrfTokenRepository); + } + @Configuration static class AllowHttpMethodsFirewallConfig { diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index e9220895fb..4da7fbc565 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -300,6 +300,39 @@ public class CsrfConfigTests { // @formatter:on } + @Test + public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception { + this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers")) + .autowire(); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/ok")) + .andExpect(status().isOk()) + .andReturn(); + MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession(); + CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf"); + MockHttpServletRequestBuilder ok = post("/ok") + .header(csrfToken.getHeaderName(), csrfToken.getToken()) + .session(session); + this.mvc.perform(ok).andExpect(status().isOk()); + // @formatter:on + } + + @Test + public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception { + this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers")) + .autowire(); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/ok")) + .andExpect(status().isOk()) + .andReturn(); + MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession(); + MockHttpServletRequestBuilder ok = post("/ok") + .with(csrf()) + .session(session); + this.mvc.perform(ok).andExpect(status().isForbidden()); + // @formatter:on + } + @Test public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication() throws Exception { diff --git a/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithXorCsrfTokenRequestAttributeHandler.xml b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithXorCsrfTokenRequestAttributeHandler.xml new file mode 100644 index 0000000000..beee0c88c7 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithXorCsrfTokenRequestAttributeHandler.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + diff --git a/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc b/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc index 39447e3ea4..39a56ec3c0 100644 --- a/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc +++ b/docs/modules/ROOT/pages/servlet/exploits/csrf.adoc @@ -163,13 +163,76 @@ class SecurityConfig { ---- ==== +[[servlet-csrf-configure-request-handler]] +==== Configure CsrfTokenRequestHandler + +Spring Security's https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfFilter.html[CsrfFilter] exposes a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfToken.html[CsrfToken] as an `HttpServletRequest` attribute named `_csrf` with the help of a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfTokenRequestHandler.html[CsrfTokenRequestHandler]. +The default implementation is `CsrfTokenRequestAttributeHandler`. + +An alternate implementation `XorCsrfTokenRequestAttributeHandler` is available to provide protection for BREACH (see https://github.com/spring-projects/spring-security/issues/4001[gh-4001]). + +You can configure `XorCsrfTokenRequestAttributeHandler` in XML using the following: + +.Configure BREACH protection XML Configuration +==== +[source,xml] +---- + + + + + +---- +==== + +You can configure `XorCsrfTokenRequestAttributeHandler` in Java Configuration using the following: + +.Configure BREACH protection +==== +.Java +[source,java,role="primary"] +---- +@EnableWebSecurity +public class WebSecurityConfig { + + @Bean + public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + http + .csrf(csrf -> csrf + .csrfTokenRequestHandler(new XorCsrfTokenRequestAttributeHandler()) + ); + return http.build(); + } +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@EnableWebSecurity +class SecurityConfig { + + @Bean + open fun filterChain(http: HttpSecurity): SecurityFilterChain { + http { + csrf { + csrfTokenRequestHandler = XorCsrfTokenRequestAttributeHandler() + } + } + return http.build() + } +} +---- +==== + [[servlet-csrf-include]] === Include the CSRF Token In order for the xref:features/exploits/csrf.adoc#csrf-protection-stp[synchronizer token pattern] to protect against CSRF attacks, we must include the actual CSRF token in the HTTP request. This must be included in a part of the request (i.e. form parameter, HTTP header, etc) that is not automatically included in the HTTP request by the browser. -Spring Security's https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfFilter.html[CsrfFilter] exposes a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfToken.html[CsrfToken] as an `HttpServletRequest` attribute named `_csrf`. +<> that the `CsrfToken` is exposed as a request attribute. This means that any view technology can access the `CsrfToken` to expose the expected token as either a <> or <>. Fortunately, there are integrations listed below that make including the token in <> and <> requests even easier. diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java new file mode 100644 index 0000000000..160b51d19e --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import java.security.SecureRandom; +import java.util.Base64; +import java.util.function.Supplier; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.security.crypto.codec.Utf8; +import org.springframework.util.Assert; + +/** + * An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of + * masking the value of the {@link CsrfToken} on each request and resolving the raw token + * value from the masked value as either a header or parameter value of the request. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestAttributeHandler { + + private SecureRandom secureRandom = new SecureRandom(); + + /** + * Specifies the {@code SecureRandom} used to generate random bytes that are used to + * mask the value of the {@link CsrfToken} on each request. + * @param secureRandom the {@code SecureRandom} to use to generate random bytes + */ + public void setSecureRandom(SecureRandom secureRandom) { + Assert.notNull(secureRandom, "secureRandom cannot be null"); + this.secureRandom = secureRandom; + } + + @Override + public void handle(HttpServletRequest request, HttpServletResponse response, + Supplier deferredCsrfToken) { + Assert.notNull(request, "request cannot be null"); + Assert.notNull(response, "response cannot be null"); + Assert.notNull(deferredCsrfToken, "deferredCsrfToken cannot be null"); + Supplier updatedCsrfToken = deferCsrfTokenUpdate(deferredCsrfToken); + super.handle(request, response, updatedCsrfToken); + } + + private Supplier deferCsrfTokenUpdate(Supplier csrfTokenSupplier) { + return () -> { + CsrfToken csrfToken = csrfTokenSupplier.get(); + Assert.state(csrfToken != null, "csrfToken supplier returned null"); + String updatedToken = createXoredCsrfToken(this.secureRandom, csrfToken.getToken()); + return new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), updatedToken); + }; + } + + @Override + public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) { + String actualToken = super.resolveCsrfTokenValue(request, csrfToken); + return getTokenValue(actualToken, csrfToken.getToken()); + } + + private static String getTokenValue(String actualToken, String token) { + byte[] actualBytes; + try { + actualBytes = Base64.getUrlDecoder().decode(actualToken); + } + catch (Exception ex) { + return null; + } + + byte[] tokenBytes = Utf8.encode(token); + int tokenSize = tokenBytes.length; + if (actualBytes.length < tokenSize) { + return null; + } + + // extract token and random bytes + int randomBytesSize = actualBytes.length - tokenSize; + byte[] xoredCsrf = new byte[tokenSize]; + byte[] randomBytes = new byte[randomBytesSize]; + + System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); + System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + + byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); + return Utf8.decode(csrfBytes); + } + + private static String createXoredCsrfToken(SecureRandom secureRandom, String token) { + byte[] tokenBytes = Utf8.encode(token); + byte[] randomBytes = new byte[tokenBytes.length]; + secureRandom.nextBytes(randomBytes); + + byte[] xoredBytes = xorCsrf(randomBytes, tokenBytes); + byte[] combinedBytes = new byte[tokenBytes.length + randomBytes.length]; + System.arraycopy(randomBytes, 0, combinedBytes, 0, randomBytes.length); + System.arraycopy(xoredBytes, 0, combinedBytes, randomBytes.length, xoredBytes.length); + + return Base64.getUrlEncoder().encodeToString(combinedBytes); + } + + private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + int len = Math.min(randomBytes.length, csrfBytes.length); + byte[] xoredCsrf = new byte[len]; + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + for (int i = 0; i < len; i++) { + xoredCsrf[i] ^= randomBytes[i]; + } + return xoredCsrf; + } + +} diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 8b0fc1449a..4ad810329a 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -18,6 +18,7 @@ package org.springframework.security.web.csrf; import java.io.IOException; import java.util.Arrays; +import java.util.Base64; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -33,6 +34,8 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -362,6 +365,45 @@ public class CsrfFilterTests { verify(this.filterChain).doFilter(this.request, this.response); } + @Test + public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception { + given(this.requestMatcher.matches(this.request)).willReturn(false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); + XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); + requestHandler.setCsrfRequestAttributeName(this.token.getParameterName()); + this.filter.setRequestHandler(requestHandler); + this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull(); + verify(this.filterChain).doFilter(this.request, this.response); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + + CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + byte[] csrfTokenAttributeBytes = Base64.getUrlDecoder().decode(csrfTokenAttribute.getToken()); + byte[] actualTokenBytes = Utf8.encode(this.token.getToken()); + // XOR'd token length is 2x due to containing the random bytes + assertThat(csrfTokenAttributeBytes).hasSize(actualTokenBytes.length * 2); + + given(this.requestMatcher.matches(this.request)).willReturn(true); + this.request.setParameter(this.token.getParameterName(), csrfTokenAttribute.getToken()); + this.filter.doFilter(this.request, this.response, this.filterChain); + verify(this.filterChain, times(2)).doFilter(this.request, this.response); + } + + @Test + public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception { + given(this.requestMatcher.matches(this.request)).willReturn(true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)) + .willReturn(new TestDeferredCsrfToken(this.token, false)); + XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); + this.filter.setRequestHandler(requestHandler); + this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + this.filter.doFilter(this.request, this.response, this.filterChain); + verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class)); + verifyNoMoreInteractions(this.filterChain); + } + @Test public void setRequireCsrfProtectionMatcherNull() { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null)); diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java new file mode 100644 index 0000000000..8b03382451 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java @@ -0,0 +1,203 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.Base64; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.stubbing.Answer; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link XorCsrfTokenRequestAttributeHandler}. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public class XorCsrfTokenRequestAttributeHandlerTests { + + private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 }; + + private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES); + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + private CsrfToken token; + + private SecureRandom secureRandom; + + private XorCsrfTokenRequestAttributeHandler handler; + + @BeforeEach + public void setup() { + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.token = new DefaultCsrfToken("headerName", "paramName", "abc"); + this.secureRandom = mock(SecureRandom.class); + this.handler = new XorCsrfTokenRequestAttributeHandler(); + } + + @Test + public void setSecureRandomWhenNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.setSecureRandom(null)) + .withMessage("secureRandom cannot be null"); + // @formatter:on + } + + @Test + public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(null, this.response, () -> this.token)) + .withMessage("request cannot be null"); + // @formatter:on + } + + @Test + public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(this.request, null, () -> this.token)) + .withMessage("response cannot be null"); + // @formatter:on + } + + @Test + public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.handler.handle(this.request, this.response, null)) + .withMessage("deferredCsrfToken cannot be null"); + // @formatter:on + } + + @Test + public void handleWhenCsrfTokenIsNullThenThrowsIllegalStateException() { + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null)) + .withMessage("csrfToken supplier returned null"); + // @formatter:on + } + + @Test + public void handleWhenCsrfRequestAttributeSetThenUsed() { + willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray()); + + this.handler.setSecureRandom(this.secureRandom); + this.handler.setCsrfRequestAttributeName("_csrf"); + this.handler.handle(this.request, this.response, () -> this.token); + assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute("_csrf")).isNotNull(); + + CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute("_csrf"); + assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE); + } + + @Test + public void handleWhenSecureRandomSetThenUsed() { + this.handler.setSecureRandom(this.secureRandom); + this.handler.handle(this.request, this.response, () -> this.token); + verify(this.secureRandom).nextBytes(anyByteArray()); + verifyNoMoreInteractions(this.secureRandom); + } + + @Test + public void handleWhenValidParametersThenRequestAttributesSet() { + willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray()); + + this.handler.setSecureRandom(this.secureRandom); + this.handler.handle(this.request, this.response, () -> this.token); + verify(this.secureRandom).nextBytes(anyByteArray()); + assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull(); + + CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE); + } + + @Test + public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token)) + .withMessage("request cannot be null"); + } + + @Test + public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null)) + .withMessage("csrfToken cannot be null"); + } + + @Test + public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() { + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isNull(); + } + + @Test + public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() { + this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isEqualTo(this.token.getToken()); + } + + @Test + public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() { + this.request.addHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isEqualTo(this.token.getToken()); + } + + @Test + public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() { + this.request.addHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + this.request.setParameter(this.token.getParameterName(), "invalid"); + String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token); + assertThat(tokenValue).isEqualTo(this.token.getToken()); + } + + private static Answer fillByteArray() { + return (invocation) -> { + byte[] bytes = invocation.getArgument(0); + Arrays.fill(bytes, (byte) 1); + return null; + }; + } + + private static byte[] anyByteArray() { + return any(byte[].class); + } + +}