Merge branch '5.8.x'

# Conflicts:
#	docs/modules/ROOT/pages/servlet/exploits/csrf.adoc
This commit is contained in:
Steve Riesenberg 2022-10-05 14:46:15 -05:00
commit 8b490de08d
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
7 changed files with 545 additions and 3 deletions

View File

@ -48,6 +48,7 @@ import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler;
import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.firewall.StrictHttpFirewall;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCache;
@ -64,6 +65,7 @@ import org.springframework.web.servlet.support.RequestDataValueProcessor;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.not;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -429,7 +431,7 @@ public class CsrfConfigurerTests {
} }
@Test @Test
public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception { public void getLoginWhenCsrfTokenRequestAttributeHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
@ -444,7 +446,7 @@ public class CsrfConfigurerTests {
} }
@Test @Test
public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception { public void loginWhenCsrfTokenRequestAttributeHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
@ -466,6 +468,47 @@ public class CsrfConfigurerTests {
verifyNoMoreInteractions(csrfTokenRepository); verifyNoMoreInteractions(csrfTokenRepository);
} }
@Test
public void getLoginWhenXorCsrfTokenRequestAttributeHandlerSetThenRespondsWithMaskedCsrfToken() throws Exception {
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new XorCsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
this.mvc.perform(get("/login")).andExpect(status().isOk())
.andExpect(content().string(not(containsString(csrfToken.getToken()))));
verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository);
}
@Test
public void loginWhenXorCsrfTokenRequestAttributeHandlerSetAndMaskedCsrfTokenThenSuccess() throws Exception {
CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
.willReturn(new TestDeferredCsrfToken(csrfToken));
CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
CsrfTokenRequestHandlerConfig.HANDLER = new XorCsrfTokenRequestAttributeHandler();
this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
MvcResult mvcResult = this.mvc.perform(get("/login")).andReturn();
CsrfToken csrfTokenAttribute = (CsrfToken) mvcResult.getRequest().getAttribute(CsrfToken.class.getName());
// @formatter:off
MockHttpServletRequestBuilder loginRequest = post("/login")
.header(csrfToken.getHeaderName(), csrfTokenAttribute.getToken())
.param("username", "user")
.param("password", "password");
// @formatter:on
this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
verify(csrfTokenRepository, times(3)).loadDeferredToken(any(HttpServletRequest.class),
any(HttpServletResponse.class));
verifyNoMoreInteractions(csrfTokenRepository);
}
@Configuration @Configuration
static class AllowHttpMethodsFirewallConfig { static class AllowHttpMethodsFirewallConfig {

View File

@ -300,6 +300,39 @@ public class CsrfConfigTests {
// @formatter:on // @formatter:on
} }
@Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire();
// @formatter:off
MvcResult mvcResult = this.mvc.perform(get("/ok"))
.andExpect(status().isOk())
.andReturn();
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf");
MockHttpServletRequestBuilder ok = post("/ok")
.header(csrfToken.getHeaderName(), csrfToken.getToken())
.session(session);
this.mvc.perform(ok).andExpect(status().isOk());
// @formatter:on
}
@Test
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception {
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
.autowire();
// @formatter:off
MvcResult mvcResult = this.mvc.perform(get("/ok"))
.andExpect(status().isOk())
.andReturn();
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
MockHttpServletRequestBuilder ok = post("/ok")
.with(csrf())
.session(session);
this.mvc.perform(ok).andExpect(status().isForbidden());
// @formatter:on
}
@Test @Test
public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication() public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication()
throws Exception { throws Exception {

View File

@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2002-2018 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<b:beans xmlns:b="http://www.springframework.org/schema/beans"
xmlns:p="http://www.springframework.org/schema/p"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://www.springframework.org/schema/security"
xsi:schemaLocation="http://www.springframework.org/schema/security https://www.springframework.org/schema/security/spring-security.xsd
http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd">
<http auto-config="true">
<csrf request-handler-ref="requestHandler"/>
</http>
<b:bean id="requestHandler" class="org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler"
p:csrfRequestAttributeName="_csrf"/>
<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
</b:beans>

View File

@ -164,13 +164,76 @@ class SecurityConfig {
---- ----
==== ====
[[servlet-csrf-configure-request-handler]]
==== Configure CsrfTokenRequestHandler
Spring Security's https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfFilter.html[`CsrfFilter`] exposes a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfToken.html[`CsrfToken`] as an `HttpServletRequest` attribute named `_csrf` with the help of a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfTokenRequestHandler.html[CsrfTokenRequestHandler].
The default implementation is `CsrfTokenRequestAttributeHandler`.
An alternate implementation `XorCsrfTokenRequestAttributeHandler` is available to provide protection for BREACH (see https://github.com/spring-projects/spring-security/issues/4001[gh-4001]).
You can configure `XorCsrfTokenRequestAttributeHandler` in XML using the following:
.Configure BREACH protection XML Configuration
====
[source,xml]
----
<http>
<!-- ... -->
<csrf request-handler-ref="requestHandler"/>
</http>
<b:bean id="requestHandler"
class="org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler"/>
----
====
You can configure `XorCsrfTokenRequestAttributeHandler` in Java Configuration using the following:
.Configure BREACH protection
====
.Java
[source,java,role="primary"]
----
@EnableWebSecurity
public class WebSecurityConfig {
@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
http
.csrf(csrf -> csrf
.csrfTokenRequestHandler(new XorCsrfTokenRequestAttributeHandler())
);
return http.build();
}
}
----
.Kotlin
[source,kotlin,role="secondary"]
----
@EnableWebSecurity
class SecurityConfig {
@Bean
open fun filterChain(http: HttpSecurity): SecurityFilterChain {
http {
csrf {
csrfTokenRequestHandler = XorCsrfTokenRequestAttributeHandler()
}
}
return http.build()
}
}
----
====
[[servlet-csrf-include]] [[servlet-csrf-include]]
=== Include the CSRF Token === Include the CSRF Token
For the xref:features/exploits/csrf.adoc#csrf-protection-stp[synchronizer token pattern] to protect against CSRF attacks, we must include the actual CSRF token in the HTTP request. For the xref:features/exploits/csrf.adoc#csrf-protection-stp[synchronizer token pattern] to protect against CSRF attacks, we must include the actual CSRF token in the HTTP request.
This must be included in a part of the request (a form parameter, an HTTP header, or other part) that is not automatically included in the HTTP request by the browser. This must be included in a part of the request (a form parameter, an HTTP header, or other part) that is not automatically included in the HTTP request by the browser.
Spring Security's https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfFilter.html[`CsrfFilter`] exposes a https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/csrf/CsrfToken.html[`CsrfToken`] as an `HttpServletRequest` attribute named `_csrf`. <<servlet-csrf-configure-request-handler,We've seen>> that the `CsrfToken` is exposed as a request attribute.
This means that any view technology can access the `CsrfToken` to expose the expected token as either a <<servlet-csrf-include-form-attr,form>> or <<servlet-csrf-include-ajax-meta-attr,meta tag>>. This means that any view technology can access the `CsrfToken` to expose the expected token as either a <<servlet-csrf-include-form-attr,form>> or <<servlet-csrf-include-ajax-meta-attr,meta tag>>.
Fortunately, there are integrations listed later in this chapter that make including the token in <<servlet-csrf-include-form,form>> and <<servlet-csrf-include-ajax,ajax>> requests even easier. Fortunately, there are integrations listed later in this chapter that make including the token in <<servlet-csrf-include-form,form>> and <<servlet-csrf-include-ajax,ajax>> requests even easier.

View File

@ -0,0 +1,126 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.function.Supplier;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.util.Assert;
/**
* An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of
* masking the value of the {@link CsrfToken} on each request and resolving the raw token
* value from the masked value as either a header or parameter value of the request.
*
* @author Steve Riesenberg
* @since 5.8
*/
public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestAttributeHandler {
private SecureRandom secureRandom = new SecureRandom();
/**
* Specifies the {@code SecureRandom} used to generate random bytes that are used to
* mask the value of the {@link CsrfToken} on each request.
* @param secureRandom the {@code SecureRandom} to use to generate random bytes
*/
public void setSecureRandom(SecureRandom secureRandom) {
Assert.notNull(secureRandom, "secureRandom cannot be null");
this.secureRandom = secureRandom;
}
@Override
public void handle(HttpServletRequest request, HttpServletResponse response,
Supplier<CsrfToken> deferredCsrfToken) {
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
Assert.notNull(deferredCsrfToken, "deferredCsrfToken cannot be null");
Supplier<CsrfToken> updatedCsrfToken = deferCsrfTokenUpdate(deferredCsrfToken);
super.handle(request, response, updatedCsrfToken);
}
private Supplier<CsrfToken> deferCsrfTokenUpdate(Supplier<CsrfToken> csrfTokenSupplier) {
return () -> {
CsrfToken csrfToken = csrfTokenSupplier.get();
Assert.state(csrfToken != null, "csrfToken supplier returned null");
String updatedToken = createXoredCsrfToken(this.secureRandom, csrfToken.getToken());
return new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), updatedToken);
};
}
@Override
public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
String actualToken = super.resolveCsrfTokenValue(request, csrfToken);
return getTokenValue(actualToken, csrfToken.getToken());
}
private static String getTokenValue(String actualToken, String token) {
byte[] actualBytes;
try {
actualBytes = Base64.getUrlDecoder().decode(actualToken);
}
catch (Exception ex) {
return null;
}
byte[] tokenBytes = Utf8.encode(token);
int tokenSize = tokenBytes.length;
if (actualBytes.length < tokenSize) {
return null;
}
// extract token and random bytes
int randomBytesSize = actualBytes.length - tokenSize;
byte[] xoredCsrf = new byte[tokenSize];
byte[] randomBytes = new byte[randomBytesSize];
System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
return Utf8.decode(csrfBytes);
}
private static String createXoredCsrfToken(SecureRandom secureRandom, String token) {
byte[] tokenBytes = Utf8.encode(token);
byte[] randomBytes = new byte[tokenBytes.length];
secureRandom.nextBytes(randomBytes);
byte[] xoredBytes = xorCsrf(randomBytes, tokenBytes);
byte[] combinedBytes = new byte[tokenBytes.length + randomBytes.length];
System.arraycopy(randomBytes, 0, combinedBytes, 0, randomBytes.length);
System.arraycopy(xoredBytes, 0, combinedBytes, randomBytes.length, xoredBytes.length);
return Base64.getUrlEncoder().encodeToString(combinedBytes);
}
private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
int len = Math.min(randomBytes.length, csrfBytes.length);
byte[] xoredCsrf = new byte[len];
System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
for (int i = 0; i < len; i++) {
xoredCsrf[i] ^= randomBytes[i];
}
return xoredCsrf;
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64;
import jakarta.servlet.FilterChain; import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
@ -32,6 +33,8 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
@ -363,6 +366,45 @@ public class CsrfFilterTests {
verify(this.filterChain).doFilter(this.request, this.response); verify(this.filterChain).doFilter(this.request, this.response);
} }
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
requestHandler.setCsrfRequestAttributeName(this.token.getParameterName());
this.filter.setRequestHandler(requestHandler);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
verify(this.filterChain).doFilter(this.request, this.response);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
byte[] csrfTokenAttributeBytes = Base64.getUrlDecoder().decode(csrfTokenAttribute.getToken());
byte[] actualTokenBytes = Utf8.encode(this.token.getToken());
// XOR'd token length is 2x due to containing the random bytes
assertThat(csrfTokenAttributeBytes).hasSize(actualTokenBytes.length * 2);
given(this.requestMatcher.matches(this.request)).willReturn(true);
this.request.setParameter(this.token.getParameterName(), csrfTokenAttribute.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(this.filterChain, times(2)).doFilter(this.request, this.response);
}
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
this.filter.setRequestHandler(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
verifyNoMoreInteractions(this.filterChain);
}
@Test @Test
public void setRequireCsrfProtectionMatcherNull() { public void setRequireCsrfProtectionMatcherNull() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null)); assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));

View File

@ -0,0 +1,203 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* Tests for {@link XorCsrfTokenRequestAttributeHandler}.
*
* @author Steve Riesenberg
* @since 5.8
*/
public class XorCsrfTokenRequestAttributeHandlerTests {
private static final byte[] XOR_CSRF_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 };
private static final String XOR_CSRF_TOKEN_VALUE = Base64.getEncoder().encodeToString(XOR_CSRF_TOKEN_BYTES);
private MockHttpServletRequest request;
private MockHttpServletResponse response;
private CsrfToken token;
private SecureRandom secureRandom;
private XorCsrfTokenRequestAttributeHandler handler;
@BeforeEach
public void setup() {
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.token = new DefaultCsrfToken("headerName", "paramName", "abc");
this.secureRandom = mock(SecureRandom.class);
this.handler = new XorCsrfTokenRequestAttributeHandler();
}
@Test
public void setSecureRandomWhenNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.setSecureRandom(null))
.withMessage("secureRandom cannot be null");
// @formatter:on
}
@Test
public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(null, this.response, () -> this.token))
.withMessage("request cannot be null");
// @formatter:on
}
@Test
public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(this.request, null, () -> this.token))
.withMessage("response cannot be null");
// @formatter:on
}
@Test
public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.handler.handle(this.request, this.response, null))
.withMessage("deferredCsrfToken cannot be null");
// @formatter:on
}
@Test
public void handleWhenCsrfTokenIsNullThenThrowsIllegalStateException() {
// @formatter:off
assertThatIllegalStateException()
.isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null))
.withMessage("csrfToken supplier returned null");
// @formatter:on
}
@Test
public void handleWhenCsrfRequestAttributeSetThenUsed() {
willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray());
this.handler.setSecureRandom(this.secureRandom);
this.handler.setCsrfRequestAttributeName("_csrf");
this.handler.handle(this.request, this.response, () -> this.token);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute("_csrf")).isNotNull();
CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute("_csrf");
assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE);
}
@Test
public void handleWhenSecureRandomSetThenUsed() {
this.handler.setSecureRandom(this.secureRandom);
this.handler.handle(this.request, this.response, () -> this.token);
verify(this.secureRandom).nextBytes(anyByteArray());
verifyNoMoreInteractions(this.secureRandom);
}
@Test
public void handleWhenValidParametersThenRequestAttributesSet() {
willAnswer(fillByteArray()).given(this.secureRandom).nextBytes(anyByteArray());
this.handler.setSecureRandom(this.secureRandom);
this.handler.handle(this.request, this.response, () -> this.token);
verify(this.secureRandom).nextBytes(anyByteArray());
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE);
}
@Test
public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
.withMessage("request cannot be null");
}
@Test
public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
.withMessage("csrfToken cannot be null");
}
@Test
public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() {
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isNull();
}
@Test
public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() {
this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
@Test
public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() {
this.request.addHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
@Test
public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() {
this.request.addHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
this.request.setParameter(this.token.getParameterName(), "invalid");
String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
assertThat(tokenValue).isEqualTo(this.token.getToken());
}
private static Answer<Void> fillByteArray() {
return (invocation) -> {
byte[] bytes = invocation.getArgument(0);
Arrays.fill(bytes, (byte) 1);
return null;
};
}
private static byte[] anyByteArray() {
return any(byte[].class);
}
}