diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java index 1e7b09fe6d..44970da453 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,8 @@ final class WebSocketMessageBrokerSecurityConfiguration private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping"; + private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor"; + private MessageMatcherDelegatingAuthorizationManager b; private static final AuthorizationManager> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager @@ -69,7 +71,7 @@ final class WebSocketMessageBrokerSecurityConfiguration private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor(); - private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); + private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); private AuthorizationManager> authorizationManager = ANY_MESSAGE_AUTHENTICATED; @@ -90,6 +92,12 @@ final class WebSocketMessageBrokerSecurityConfiguration @Override public void configureClientInboundChannel(ChannelRegistration registration) { + ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME, + ChannelInterceptor.class); + if (csrfChannelInterceptor != null) { + this.csrfChannelInterceptor = csrfChannelInterceptor; + } + AuthorizationManager> manager = this.authorizationManager; if (!this.observationRegistry.isNoop()) { manager = new ObservationAuthorizationManager<>(this.observationRegistry, manager); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index 006bae988d..f8f651c782 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,6 +60,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; @@ -78,6 +79,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @@ -283,7 +285,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(handshakeHandler.attributes.get(this.sessionAttr)) .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); } @@ -305,7 +307,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); request.getSession().setAttribute(this.sessionAttr, "sessionValue"); - request.setAttribute(CsrfToken.class.getName(), this.token); + request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)); return request; } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java new file mode 100644 index 0000000000..11aa0de836 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.socket; + +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; + +/** + * @author Steve Riesenberg + */ +final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java index 6fac21acc3..229c46639a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,6 +69,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; @@ -91,6 +92,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.fail; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; public class WebSocketMessageBrokerSecurityConfigurationTests { @@ -366,7 +368,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(handshakeHandler.attributes.get(this.sessionAttr)) .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); } @@ -388,7 +390,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); request.getSession().setAttribute(this.sessionAttr, "sessionValue"); - request.setAttribute(CsrfToken.class.getName(), this.token); + request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)); return request; } diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index f9c71e59fb..c94dcce745 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.InvalidCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** @@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests { MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token) - .sessionAttr(customAttributeName, "attributeValue")).andReturn(); + MvcResult result = mvc.perform( + get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)) + .sessionAttr(customAttributeName, "attributeValue")) + .andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); + assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests { MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token) + MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket") + .requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)) .sessionAttr(customAttributeName, "attributeValue")).andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); + assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests { return SecurityContextHolder.getContextHolderStrategy(); } + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + + } + @Controller static class MessageController { diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java new file mode 100644 index 0000000000..2d7b3d1c8c --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.security.MessageDigest; +import java.util.Map; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.security.crypto.codec.Utf8; +import org.springframework.security.messaging.util.matcher.MessageMatcher; +import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +/** + * {@link ChannelInterceptor} that validates a CSRF token masked by the + * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in + * the header of any {@link SimpMessageType#CONNECT} message. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public final class XorCsrfChannelInterceptor implements ChannelInterceptor { + + private final MessageMatcher matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT); + + @Override + public Message preSend(Message message, MessageChannel channel) { + if (!this.matcher.matches(message)) { + return message; + } + Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); + CsrfToken expectedToken = (sessionAttributes != null) + ? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null; + if (expectedToken == null) { + throw new MissingCsrfTokenException(null); + } + String actualToken = SimpMessageHeaderAccessor.wrap(message) + .getFirstNativeHeader(expectedToken.getHeaderName()); + String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken()); + boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue); + if (!csrfCheckPassed) { + throw new InvalidCsrfTokenException(expectedToken, actualToken); + } + return message; + } + + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + if (expected == actual) { + return true; + } + if (expected == null || actual == null) { + return false; + } + // Encode after ensure that the string is not null + byte[] expectedBytes = Utf8.encode(expected); + byte[] actualBytes = Utf8.encode(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java new file mode 100644 index 0000000000..46a67cc4d3 --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.util.Base64; + +import org.springframework.security.crypto.codec.Utf8; + +/** + * Copied from + * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}. + * + * @see gh-12378 + */ +final class XorCsrfTokenUtils { + + private XorCsrfTokenUtils() { + } + + static String getTokenValue(String actualToken, String token) { + byte[] actualBytes; + try { + actualBytes = Base64.getUrlDecoder().decode(actualToken); + } + catch (Exception ex) { + return null; + } + + byte[] tokenBytes = Utf8.encode(token); + int tokenSize = tokenBytes.length; + if (actualBytes.length < tokenSize) { + return null; + } + + // extract token and random bytes + int randomBytesSize = actualBytes.length - tokenSize; + byte[] xoredCsrf = new byte[tokenSize]; + byte[] randomBytes = new byte[randomBytesSize]; + + System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); + System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + + byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); + return Utf8.decode(csrfBytes); + } + + private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + int len = Math.min(randomBytes.length, csrfBytes.length); + byte[] xoredCsrf = new byte[len]; + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + for (int i = 0; i < len; i++) { + xoredCsrf[i] ^= randomBytes[i]; + } + return xoredCsrf; + } + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java index aa936c91fe..3c4a785f36 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; /** - * Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket - * attributes. This is used as the expected CsrfToken when validating connection requests - * to ensure only the same origin connects. + * Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the + * WebSocket attributes. This is used as the expected CsrfToken when validating connection + * requests to ensure only the same origin connects. * * @author Rob Winch + * @author Steve Riesenberg * @since 4.0 */ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor { @@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) { HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest(); - CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName()); - if (token == null) { + DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest + .getAttribute(DeferredCsrfToken.class.getName()); + if (deferredCsrfToken == null) { return true; } - attributes.put(CsrfToken.class.getName(), token); + CsrfToken csrfToken = deferredCsrfToken.get(); + // Ensure the values of the CsrfToken are copied into a new token so the old token + // is available for garbage collection. + // This is required because the original token could hold a reference to the + // HttpServletRequest/Response of the handshake request. + CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), + csrfToken.getToken()); + attributes.put(CsrfToken.class.getName(), resolvedCsrfToken); return true; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java new file mode 100644 index 0000000000..884c3d2fc2 --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.util.HashMap; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link XorCsrfChannelInterceptor}. + * + * @author Steve Riesenberg + */ +public class XorCsrfChannelInterceptorTests { + + private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA=="; + + private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ=="; + + private CsrfToken token; + + private SimpMessageHeaderAccessor messageHeaders; + + private MessageChannel channel; + + private XorCsrfChannelInterceptor interceptor; + + @BeforeEach + public void setup() { + this.token = new DefaultCsrfToken("header", "param", "token"); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + this.messageHeaders.setSessionAttributes(new HashMap<>()); + this.channel = mock(MessageChannel.class); + this.interceptor = new XorCsrfChannelInterceptor(); + } + + @Test + public void preSendWhenConnectWithValidTokenThenSuccess() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() { + // @formatter:off + assertThatExceptionOfType(MissingCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() { + this.messageHeaders.setSessionAttributes(null); + // @formatter:off + assertThatExceptionOfType(MissingCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenAckThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenDisconnectThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenHeartbeatThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenMessageThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenOtherThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenUnsubscribeThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + this.interceptor.preSend(message(), this.channel); + } + + private Message message() { + return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); + } + +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java index a3c38d14d9..760fe3aa1b 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.web.socket.WebSocketHandler; import static org.assertj.core.api.Assertions.assertThat; @@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests { @Test public void beforeHandshake() throws Exception { CsrfToken token = new DefaultCsrfToken("header", "param", "token"); - this.httpRequest.setAttribute(CsrfToken.class.getName(), token); + this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token)); this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes); assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName()); - assertThat(this.attributes.values()).containsOnly(token); + CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName()); + assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName()); + assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName()); + assertThat(csrfToken.getToken()).isEqualTo(token.getToken()); + // Ensure the values of the CsrfToken are copied into a new token so the old token + // is available for garbage collection. + // This is required because the original token could hold a reference to the + // HttpServletRequest/Response of the handshake request. + assertThat(csrfToken).isNotSameAs(token); + } + + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + private TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 06afa7747d..ccce4f4c57 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -107,6 +107,7 @@ public final class CsrfFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); + request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken); this.requestHandler.handle(request, response, deferredCsrfToken::get); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index c57e1bfc7d..96dba3f92b 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -127,11 +127,12 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -139,12 +140,13 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -152,12 +154,13 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -166,8 +169,8 @@ public class CsrfFilterTests { public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); handler.handle(this.request, this.response, () -> this.token); CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); @@ -176,6 +179,7 @@ public class CsrfFilterTests { this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -183,11 +187,12 @@ public class CsrfFilterTests { @Test public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -195,11 +200,12 @@ public class CsrfFilterTests { @Test public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -207,8 +213,8 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); handler.handle(this.request, this.response, () -> this.token); CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); @@ -216,6 +222,7 @@ public class CsrfFilterTests { this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -224,8 +231,8 @@ public class CsrfFilterTests { public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); handler.handle(this.request, this.response, () -> this.token); CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); @@ -234,6 +241,7 @@ public class CsrfFilterTests { this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -241,8 +249,8 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); handler.handle(this.request, this.response, () -> this.token); CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); @@ -250,6 +258,7 @@ public class CsrfFilterTests { this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), @@ -259,8 +268,8 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); handler.handle(this.request, this.response, () -> this.token); CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); @@ -268,6 +277,7 @@ public class CsrfFilterTests { this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); @@ -332,11 +342,12 @@ public class CsrfFilterTests { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain); } @@ -360,22 +371,24 @@ public class CsrfFilterTests { given(token.getToken()).willReturn(null); given(token.getHeaderName()).willReturn(this.token.getHeaderName()); given(token.getParameterName()).willReturn(this.token.getParameterName()); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); given(this.requestMatcher.matches(this.request)).willReturn(true); filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void doFilterWhenRequestHandlerThenUsed() throws Exception { - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); this.filter = createCsrfFilter(this.tokenRepository); this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.tokenRepository).loadDeferredToken(this.request, this.response); verify(requestHandler).handle(eq(this.request), eq(this.response), any()); verify(this.filterChain).doFilter(this.request, this.response); @@ -384,11 +397,12 @@ public class CsrfFilterTests { @Test public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); assertThat(this.request.getAttribute("_csrf")).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -407,10 +421,11 @@ public class CsrfFilterTests { @Test public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -435,10 +450,11 @@ public class CsrfFilterTests { requestHandler.setCsrfRequestAttributeName(csrfAttrName); filter.setRequestHandler(requestHandler); CsrfToken expectedCsrfToken = mock(CsrfToken.class); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);