Merge branch '6.0.x'

This commit is contained in:
Steve Riesenberg 2023-01-26 15:55:41 -06:00
commit 6abbdd3654
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
14 changed files with 518 additions and 61 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,7 +41,7 @@ import org.springframework.security.messaging.access.intercept.AuthorizationChan
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.util.Assert;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
@ -59,6 +59,8 @@ final class WebSocketMessageBrokerSecurityConfiguration
private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
private MessageMatcherDelegatingAuthorizationManager b;
private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
@ -69,7 +71,7 @@ final class WebSocketMessageBrokerSecurityConfiguration
private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
private ChannelInterceptor csrfChannelInterceptor = new XorCsrfChannelInterceptor();
private AuthorizationManager<Message<?>> authorizationManager = ANY_MESSAGE_AUTHENTICATED;
@ -90,6 +92,12 @@ final class WebSocketMessageBrokerSecurityConfiguration
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
ChannelInterceptor.class);
if (csrfChannelInterceptor != null) {
this.csrfChannelInterceptor = csrfChannelInterceptor;
}
AuthorizationManager<Message<?>> manager = this.authorizationManager;
if (!this.observationRegistry.isNoop()) {
manager = new ObservationAuthorizationManager<>(this.observationRegistry, manager);

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -60,6 +60,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
@ -78,6 +79,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
@ -283,7 +285,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
}
@ -305,7 +307,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), this.token);
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
return request;
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.socket;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
/**
* @author Steve Riesenberg
*/
final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -66,9 +66,10 @@ import org.springframework.security.messaging.access.intercept.AuthorizationChan
import org.springframework.security.messaging.access.intercept.MessageAuthorizationContext;
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
@ -91,9 +92,12 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
public class WebSocketMessageBrokerSecurityConfigurationTests {
private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
AnnotationConfigWebApplicationContext context;
Authentication messageUser;
@ -196,7 +200,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
MessageChannel messageChannel = clientInboundChannel();
Stream<Class<? extends ChannelInterceptor>> interceptors = ((AbstractMessageChannel) messageChannel)
.getInterceptors().stream().map(ChannelInterceptor::getClass);
assertThat(interceptors).contains(CsrfChannelInterceptor.class);
assertThat(interceptors).contains(XorCsrfChannelInterceptor.class);
}
@Test
@ -236,7 +240,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
public void messagesContextWebSocketUseSecurityContextHolderStrategy() {
loadConfig(WebSocketSecurityConfig.class, SecurityContextChangedListenerConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken());
headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
Message<?> message = message(headers, "/authenticated");
headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
MessageChannel messageChannel = clientInboundChannel();
@ -366,7 +370,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
}
@ -388,7 +392,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), this.token);
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
return request;
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension;
@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
/**
@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests {
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
String csrfAttributeName = CsrfToken.class.getName();
String customAttributeName = this.getClass().getName();
MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
MvcResult result = mvc.perform(
get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
.sessionAttr(customAttributeName, "attributeValue"))
.andReturn();
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThat(handshakeValue).isEqualTo(sessionValue)
.withFailMessage("Explicitly listed session variables are not overridden");
}
@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests {
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
String csrfAttributeName = CsrfToken.class.getName();
String customAttributeName = this.getClass().getName();
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThat(handshakeValue).isEqualTo(sessionValue)
.withFailMessage("Explicitly listed session variables are not overridden");
}
@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests {
return SecurityContextHolder.getContextHolderStrategy();
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}
@Controller
static class MessageController {

View File

@ -5,6 +5,7 @@
* xref:migration/index.adoc[Migrating to 6.0]
** xref:migration/servlet/index.adoc[Servlet Migrations]
*** xref:migration/servlet/session-management.adoc[Session Management]
*** xref:migration/servlet/exploits.adoc[Exploit Protection]
*** xref:migration/servlet/authentication.adoc[Authentication]
*** xref:migration/servlet/authorization.adoc[Authorization]
** xref:migration/reactive.adoc[Reactive Migrations]

View File

@ -0,0 +1,11 @@
= Exploit Protection Migrations
The following steps relate to how to finish migrating exploit protection support.
== CSRF BREACH with WebSocket support
In Spring Security 5.8, the default `ChannelInterceptor` for making the `CsrfToken` available with xref:servlet/integrations/websocket.adoc[WebSocket Security] is `CsrfChannelInterceptor`.
`XorCsrfChannelInterceptor` was added to allow opting into CSRF BREACH support.
In Spring Security 6, `XorCsrfChannelInterceptor` is the default `ChannelInterceptor` for making the `CsrfToken` available.
If you configured the `XorCsrfChannelInterceptor` only for the purpose of updating to 6.0, you can remove it completely.

View File

@ -0,0 +1,86 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.security.MessageDigest;
import java.util.Map;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.messaging.util.matcher.MessageMatcher;
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
/**
* {@link ChannelInterceptor} that validates a CSRF token masked by the
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in
* the header of any {@link SimpMessageType#CONNECT} message.
*
* @author Steve Riesenberg
* @since 5.8
*/
public final class XorCsrfChannelInterceptor implements ChannelInterceptor {
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
if (!this.matcher.matches(message)) {
return message;
}
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
CsrfToken expectedToken = (sessionAttributes != null)
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
if (expectedToken == null) {
throw new MissingCsrfTokenException(null);
}
String actualToken = SimpMessageHeaderAccessor.wrap(message)
.getFirstNativeHeader(expectedToken.getHeaderName());
String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken());
boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue);
if (!csrfCheckPassed) {
throw new InvalidCsrfTokenException(expectedToken, actualToken);
}
return message;
}
/**
* Constant time comparison to prevent against timing attacks.
* @param expected
* @param actual
* @return
*/
private static boolean equalsConstantTime(String expected, String actual) {
if (expected == actual) {
return true;
}
if (expected == null || actual == null) {
return false;
}
// Encode after ensure that the string is not null
byte[] expectedBytes = Utf8.encode(expected);
byte[] actualBytes = Utf8.encode(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.util.Base64;
import org.springframework.security.crypto.codec.Utf8;
/**
* Copied from
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}.
*
* @see <a href=
* "https://github.com/spring-projects/spring-security/issues/12378">gh-12378</a>
*/
final class XorCsrfTokenUtils {
private XorCsrfTokenUtils() {
}
static String getTokenValue(String actualToken, String token) {
byte[] actualBytes;
try {
actualBytes = Base64.getUrlDecoder().decode(actualToken);
}
catch (Exception ex) {
return null;
}
byte[] tokenBytes = Utf8.encode(token);
int tokenSize = tokenBytes.length;
if (actualBytes.length < tokenSize) {
return null;
}
// extract token and random bytes
int randomBytesSize = actualBytes.length - tokenSize;
byte[] xoredCsrf = new byte[tokenSize];
byte[] randomBytes = new byte[randomBytesSize];
System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
return Utf8.decode(csrfBytes);
}
private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
int len = Math.min(randomBytes.length, csrfBytes.length);
byte[] xoredCsrf = new byte[len];
System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
for (int i = 0; i < len; i++) {
xoredCsrf[i] ^= randomBytes[i];
}
return xoredCsrf;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket
* attributes. This is used as the expected CsrfToken when validating connection requests
* to ensure only the same origin connects.
* Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the
* WebSocket attributes. This is used as the expected CsrfToken when validating connection
* requests to ensure only the same origin connects.
*
* @author Rob Winch
* @author Steve Riesenberg
* @since 4.0
*/
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
if (token == null) {
DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest
.getAttribute(DeferredCsrfToken.class.getName());
if (deferredCsrfToken == null) {
return true;
}
attributes.put(CsrfToken.class.getName(), token);
CsrfToken csrfToken = deferredCsrfToken.get();
// Ensure the values of the CsrfToken are copied into a new token so the old token
// is available for garbage collection.
// This is required because the original token could hold a reference to the
// HttpServletRequest/Response of the handshake request.
CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(),
csrfToken.getToken());
attributes.put(CsrfToken.class.getName(), resolvedCsrfToken);
return true;
}

View File

@ -0,0 +1,148 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.util.HashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link XorCsrfChannelInterceptor}.
*
* @author Steve Riesenberg
*/
public class XorCsrfChannelInterceptorTests {
private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ==";
private CsrfToken token;
private SimpMessageHeaderAccessor messageHeaders;
private MessageChannel channel;
private XorCsrfChannelInterceptor interceptor;
@BeforeEach
public void setup() {
this.token = new DefaultCsrfToken("header", "param", "token");
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
this.messageHeaders.setSessionAttributes(new HashMap<>());
this.channel = mock(MessageChannel.class);
this.interceptor = new XorCsrfChannelInterceptor();
}
@Test
public void preSendWhenConnectWithValidTokenThenSuccess() {
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() {
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE);
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
// @formatter:off
assertThatExceptionOfType(InvalidCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() {
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
// @formatter:off
assertThatExceptionOfType(InvalidCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() {
// @formatter:off
assertThatExceptionOfType(MissingCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() {
this.messageHeaders.setSessionAttributes(null);
// @formatter:off
assertThatExceptionOfType(MissingCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenAckThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenDisconnectThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenHeartbeatThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenMessageThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenOtherThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenUnsubscribeThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
this.interceptor.preSend(message(), this.channel);
}
private Message<String> message() {
return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.web.socket.WebSocketHandler;
import static org.assertj.core.api.Assertions.assertThat;
@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests {
@Test
public void beforeHandshake() throws Exception {
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
this.httpRequest.setAttribute(CsrfToken.class.getName(), token);
this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token));
this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes);
assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName());
assertThat(this.attributes.values()).containsOnly(token);
CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName());
assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName());
assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName());
assertThat(csrfToken.getToken()).isEqualTo(token.getToken());
// Ensure the values of the CsrfToken are copied into a new token so the old token
// is available for garbage collection.
// This is required because the original token could hold a reference to the
// HttpServletRequest/Response of the handshake request.
assertThat(csrfToken).isNotSameAs(token);
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
private TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}
}

View File

@ -107,6 +107,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken);
this.requestHandler.handle(request, response, deferredCsrfToken::get);
if (!this.requireCsrfProtectionMatcher.matches(request)) {
if (this.logger.isTraceEnabled()) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -127,11 +127,12 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -139,12 +140,13 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -152,12 +154,13 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -166,8 +169,8 @@ public class CsrfFilterTests {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
handler.handle(this.request, this.response, () -> this.token);
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
@ -176,6 +179,7 @@ public class CsrfFilterTests {
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -183,11 +187,12 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -195,11 +200,12 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, true));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -207,8 +213,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
handler.handle(this.request, this.response, () -> this.token);
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
@ -216,6 +222,7 @@ public class CsrfFilterTests {
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -224,8 +231,8 @@ public class CsrfFilterTests {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
handler.handle(this.request, this.response, () -> this.token);
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
@ -234,6 +241,7 @@ public class CsrfFilterTests {
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -241,8 +249,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
handler.handle(this.request, this.response, () -> this.token);
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
@ -250,6 +258,7 @@ public class CsrfFilterTests {
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
@ -259,8 +268,8 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, true));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
handler.handle(this.request, this.response, () -> this.token);
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
@ -268,6 +277,7 @@ public class CsrfFilterTests {
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
// LazyCsrfTokenRepository requires the response as an attribute
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
verify(this.filterChain).doFilter(this.request, this.response);
@ -332,11 +342,12 @@ public class CsrfFilterTests {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
verifyNoMoreInteractions(this.filterChain);
}
@ -360,22 +371,24 @@ public class CsrfFilterTests {
given(token.getToken()).willReturn(null);
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
given(token.getParameterName()).willReturn(this.token.getParameterName());
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
given(this.requestMatcher.matches(this.request)).willReturn(true);
filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
}
@Test
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
this.filter = createCsrfFilter(this.tokenRepository);
this.filter.setRequestHandler(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
verify(requestHandler).handle(eq(this.request), eq(this.response), any());
verify(this.filterChain).doFilter(this.request, this.response);
@ -384,11 +397,12 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute("_csrf")).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
@ -407,10 +421,11 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -435,10 +450,11 @@ public class CsrfFilterTests {
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
filter.setRequestHandler(requestHandler);
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);