From c306df9b46d9803ddf1020c24ea88317fde7a037 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Mon, 9 Jan 2023 15:58:28 -0600 Subject: [PATCH 1/4] Add XorCsrfChannelInterceptor Issue gh-12378 --- ...ketMessageBrokerSecurityConfiguration.java | 12 +- ...WebSocketMessageBrokerConfigurerTests.java | 8 +- .../web/socket/TestDeferredCsrfToken.java | 43 +++++ ...ssageBrokerSecurityConfigurationTests.java | 8 +- .../WebSocketMessageBrokerConfigTests.java | 37 ++++- .../web/csrf/XorCsrfChannelInterceptor.java | 86 ++++++++++ .../messaging/web/csrf/XorCsrfTokenUtils.java | 72 +++++++++ .../server/CsrfTokenHandshakeInterceptor.java | 25 ++- .../csrf/XorCsrfChannelInterceptorTests.java | 148 ++++++++++++++++++ .../CsrfTokenHandshakeInterceptorTests.java | 35 ++++- .../security/web/csrf/CsrfFilter.java | 3 +- .../security/web/csrf/CsrfFilterTests.java | 82 ++++++---- 12 files changed, 501 insertions(+), 58 deletions(-) create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java create mode 100644 messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java create mode 100644 messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java create mode 100644 messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java index 8dfc4da381..1a23f8d5ec 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping"; + private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor"; + private MessageMatcherDelegatingAuthorizationManager b; private static final AuthorizationManager> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager @@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor(); - private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); + private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor( ANY_MESSAGE_AUTHENTICATED); @@ -86,6 +88,12 @@ final class WebSocketMessageBrokerSecurityConfiguration @Override public void configureClientInboundChannel(ChannelRegistration registration) { + ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME, + ChannelInterceptor.class); + if (csrfChannelInterceptor != null) { + this.csrfChannelInterceptor = csrfChannelInterceptor; + } + this.authorizationChannelInterceptor .setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context)); this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index 8d0ad84835..2286ff6f60 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; @@ -79,6 +80,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @@ -284,7 +286,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(handshakeHandler.attributes.get(this.sessionAttr)) .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); } @@ -306,7 +308,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); request.getSession().setAttribute(this.sessionAttr, "sessionValue"); - request.setAttribute(CsrfToken.class.getName(), this.token); + request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)); return request; } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java new file mode 100644 index 0000000000..11aa0de836 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.socket; + +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; + +/** + * @author Steve Riesenberg + */ +final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java index 8823fe39cb..6977207218 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,6 +70,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; @@ -92,6 +93,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.fail; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; public class WebSocketMessageBrokerSecurityConfigurationTests { @@ -367,7 +369,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(handshakeHandler.attributes.get(this.sessionAttr)) .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); } @@ -389,7 +391,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); request.getSession().setAttribute(this.sessionAttr, "sessionValue"); - request.setAttribute(CsrfToken.class.getName(), this.token); + request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)); return request; } diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index d39e2c9371..f9243db209 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.InvalidCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** @@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests { MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token) - .sessionAttr(customAttributeName, "attributeValue")).andReturn(); + MvcResult result = mvc.perform( + get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)) + .sessionAttr(customAttributeName, "attributeValue")) + .andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); + assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests { MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token) + MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket") + .requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)) .sessionAttr(customAttributeName, "attributeValue")).andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); + assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests { return SecurityContextHolder.getContextHolderStrategy(); } + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + + } + @Controller static class MessageController { diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java new file mode 100644 index 0000000000..2d7b3d1c8c --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.security.MessageDigest; +import java.util.Map; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.security.crypto.codec.Utf8; +import org.springframework.security.messaging.util.matcher.MessageMatcher; +import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +/** + * {@link ChannelInterceptor} that validates a CSRF token masked by the + * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in + * the header of any {@link SimpMessageType#CONNECT} message. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public final class XorCsrfChannelInterceptor implements ChannelInterceptor { + + private final MessageMatcher matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT); + + @Override + public Message preSend(Message message, MessageChannel channel) { + if (!this.matcher.matches(message)) { + return message; + } + Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); + CsrfToken expectedToken = (sessionAttributes != null) + ? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null; + if (expectedToken == null) { + throw new MissingCsrfTokenException(null); + } + String actualToken = SimpMessageHeaderAccessor.wrap(message) + .getFirstNativeHeader(expectedToken.getHeaderName()); + String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken()); + boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue); + if (!csrfCheckPassed) { + throw new InvalidCsrfTokenException(expectedToken, actualToken); + } + return message; + } + + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + if (expected == actual) { + return true; + } + if (expected == null || actual == null) { + return false; + } + // Encode after ensure that the string is not null + byte[] expectedBytes = Utf8.encode(expected); + byte[] actualBytes = Utf8.encode(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java new file mode 100644 index 0000000000..46a67cc4d3 --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.util.Base64; + +import org.springframework.security.crypto.codec.Utf8; + +/** + * Copied from + * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}. + * + * @see gh-12378 + */ +final class XorCsrfTokenUtils { + + private XorCsrfTokenUtils() { + } + + static String getTokenValue(String actualToken, String token) { + byte[] actualBytes; + try { + actualBytes = Base64.getUrlDecoder().decode(actualToken); + } + catch (Exception ex) { + return null; + } + + byte[] tokenBytes = Utf8.encode(token); + int tokenSize = tokenBytes.length; + if (actualBytes.length < tokenSize) { + return null; + } + + // extract token and random bytes + int randomBytesSize = actualBytes.length - tokenSize; + byte[] xoredCsrf = new byte[tokenSize]; + byte[] randomBytes = new byte[randomBytesSize]; + + System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); + System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + + byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); + return Utf8.decode(csrfBytes); + } + + private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + int len = Math.min(randomBytes.length, csrfBytes.length); + byte[] xoredCsrf = new byte[len]; + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + for (int i = 0; i < len; i++) { + xoredCsrf[i] ^= randomBytes[i]; + } + return xoredCsrf; + } + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java index aa40975f2f..1c917d82ee 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; /** - * Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket - * attributes. This is used as the expected CsrfToken when validating connection requests - * to ensure only the same origin connects. + * Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the + * WebSocket attributes. This is used as the expected CsrfToken when validating connection + * requests to ensure only the same origin connects. * * @author Rob Winch + * @author Steve Riesenberg * @since 4.0 */ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor { @@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) { HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest(); - CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName()); - if (token == null) { + DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest + .getAttribute(DeferredCsrfToken.class.getName()); + if (deferredCsrfToken == null) { return true; } - attributes.put(CsrfToken.class.getName(), token); + CsrfToken csrfToken = deferredCsrfToken.get(); + // Ensure the values of the CsrfToken are copied into a new token so the old token + // is available for garbage collection. + // This is required because the original token could hold a reference to the + // HttpServletRequest/Response of the handshake request. + CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), + csrfToken.getToken()); + attributes.put(CsrfToken.class.getName(), resolvedCsrfToken); return true; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java new file mode 100644 index 0000000000..884c3d2fc2 --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.util.HashMap; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link XorCsrfChannelInterceptor}. + * + * @author Steve Riesenberg + */ +public class XorCsrfChannelInterceptorTests { + + private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA=="; + + private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ=="; + + private CsrfToken token; + + private SimpMessageHeaderAccessor messageHeaders; + + private MessageChannel channel; + + private XorCsrfChannelInterceptor interceptor; + + @BeforeEach + public void setup() { + this.token = new DefaultCsrfToken("header", "param", "token"); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + this.messageHeaders.setSessionAttributes(new HashMap<>()); + this.channel = mock(MessageChannel.class); + this.interceptor = new XorCsrfChannelInterceptor(); + } + + @Test + public void preSendWhenConnectWithValidTokenThenSuccess() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() { + // @formatter:off + assertThatExceptionOfType(MissingCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() { + this.messageHeaders.setSessionAttributes(null); + // @formatter:off + assertThatExceptionOfType(MissingCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenAckThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenDisconnectThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenHeartbeatThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenMessageThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenOtherThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenUnsubscribeThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + this.interceptor.preSend(message(), this.channel); + } + + private Message message() { + return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); + } + +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java index a3c38d14d9..760fe3aa1b 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.web.socket.WebSocketHandler; import static org.assertj.core.api.Assertions.assertThat; @@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests { @Test public void beforeHandshake() throws Exception { CsrfToken token = new DefaultCsrfToken("header", "param", "token"); - this.httpRequest.setAttribute(CsrfToken.class.getName(), token); + this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token)); this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes); assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName()); - assertThat(this.attributes.values()).containsOnly(token); + CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName()); + assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName()); + assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName()); + assertThat(csrfToken.getToken()).isEqualTo(token.getToken()); + // Ensure the values of the CsrfToken are copied into a new token so the old token + // is available for garbage collection. + // This is required because the original token could hold a reference to the + // HttpServletRequest/Response of the handshake request. + assertThat(csrfToken).isNotSameAs(token); + } + + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + private TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 5f3b94b6c9..3f966832a4 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -108,6 +108,7 @@ public final class CsrfFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); + request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken); this.requestHandler.handle(request, response, deferredCsrfToken::get); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 4ad810329a..68875e05a8 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,11 +126,12 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -138,12 +139,13 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -151,12 +153,13 @@ public class CsrfFilterTests { @Test public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -165,13 +168,14 @@ public class CsrfFilterTests { public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -179,11 +183,12 @@ public class CsrfFilterTests { @Test public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -191,11 +196,12 @@ public class CsrfFilterTests { @Test public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -203,12 +209,13 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -217,13 +224,14 @@ public class CsrfFilterTests { public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -231,12 +239,13 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), @@ -246,12 +255,13 @@ public class CsrfFilterTests { @Test public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); @@ -316,11 +326,12 @@ public class CsrfFilterTests { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain); } @@ -344,22 +355,24 @@ public class CsrfFilterTests { given(token.getToken()).willReturn(null); given(token.getHeaderName()).willReturn(this.token.getHeaderName()); given(token.getParameterName()).willReturn(this.token.getParameterName()); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); given(this.requestMatcher.matches(this.request)).willReturn(true); filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void doFilterWhenRequestHandlerThenUsed() throws Exception { - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); this.filter = createCsrfFilter(this.tokenRepository); this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.tokenRepository).loadDeferredToken(this.request, this.response); verify(requestHandler).handle(eq(this.request), eq(this.response), any()); verify(this.filterChain).doFilter(this.request, this.response); @@ -368,14 +381,15 @@ public class CsrfFilterTests { @Test public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); requestHandler.setCsrfRequestAttributeName(this.token.getParameterName()); this.filter.setRequestHandler(requestHandler); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -394,12 +408,13 @@ public class CsrfFilterTests { @Test public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -424,10 +439,11 @@ public class CsrfFilterTests { requestHandler.setCsrfRequestAttributeName(csrfAttrName); filter.setRequestHandler(requestHandler); CsrfToken expectedCsrfToken = mock(CsrfToken.class); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); From 33e72b35f92607486c627ae668ff745ca9637277 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Thu, 19 Jan 2023 10:39:36 -0600 Subject: [PATCH 2/4] Add section for migrating WebSocket support Issue gh-12378 --- .../pages/migration/servlet/exploits.adoc | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/docs/modules/ROOT/pages/migration/servlet/exploits.adoc b/docs/modules/ROOT/pages/migration/servlet/exploits.adoc index 0e41a8a263..379c1f83ac 100644 --- a/docs/modules/ROOT/pages/migration/servlet/exploits.adoc +++ b/docs/modules/ROOT/pages/migration/servlet/exploits.adoc @@ -243,3 +243,65 @@ open fun springSecurity(http: HttpSecurity): SecurityFilterChain { ==== I need to opt out of CSRF BREACH protection for another reason If CSRF BREACH protection does not work for you for another reason, you can opt out using the configuration from the <> section. + +== CSRF BREACH with WebSocket support + +If the steps for <> work for normal HTTP requests and you are using xref:servlet/integrations/websocket.adoc[WebSocket Security] support, then you can also opt into Spring Security 6's default support for BREACH protection of the `CsrfToken` with xref:servlet/integrations/websocket.adoc#websocket-sameorigin-csrf[Stomp headers]. + +.WebSocket Security BREACH Protection +==== +.Java +[source,java,role="primary"] +---- +@Bean +ChannelInterceptor csrfChannelInterceptor() { + return new XorCsrfChannelInterceptor(); +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@Bean +open fun csrfChannelInterceptor(): ChannelInterceptor { + return XorCsrfChannelInterceptor() +} +---- + +.XML +[source,xml,role="secondary"] +---- + +---- +==== + +If configuring CSRF BREACH protection for WebSocket Security gives you trouble, you can configure the 5.8 default using the following configuration: + +.Configure WebSocket Security with 5.8 default +==== +.Java +[source,java,role="primary"] +---- +@Bean +ChannelInterceptor csrfChannelInterceptor() { + return new CsrfChannelInterceptor(); +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@Bean +open fun csrfChannelInterceptor(): ChannelInterceptor { + return CsrfChannelInterceptor() +} +---- + +.XML +[source,xml,role="secondary"] +---- + +---- +==== From 13487be268941d000d083c7d11b426ff77f6aa66 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Thu, 26 Jan 2023 00:03:57 -0600 Subject: [PATCH 3/4] Default to XorCsrfChannelInterceptor in 6.0.x Closes gh-12378 --- .../WebSocketMessageBrokerSecurityConfiguration.java | 4 ++-- .../WebSocketMessageBrokerSecurityConfigurationTests.java | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java index 44970da453..15deab69b6 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java @@ -41,7 +41,7 @@ import org.springframework.security.messaging.access.intercept.AuthorizationChan import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; -import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; +import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.util.Assert; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; @@ -71,7 +71,7 @@ final class WebSocketMessageBrokerSecurityConfiguration private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor(); - private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); + private ChannelInterceptor csrfChannelInterceptor = new XorCsrfChannelInterceptor(); private AuthorizationManager> authorizationManager = ANY_MESSAGE_AUTHENTICATED; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java index 229c46639a..278974000e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java @@ -66,7 +66,7 @@ import org.springframework.security.messaging.access.intercept.AuthorizationChan import org.springframework.security.messaging.access.intercept.MessageAuthorizationContext; import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; -import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; +import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.DeferredCsrfToken; @@ -96,6 +96,8 @@ import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCs public class WebSocketMessageBrokerSecurityConfigurationTests { + private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA=="; + AnnotationConfigWebApplicationContext context; Authentication messageUser; @@ -198,7 +200,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { MessageChannel messageChannel = clientInboundChannel(); Stream> interceptors = ((AbstractMessageChannel) messageChannel) .getInterceptors().stream().map(ChannelInterceptor::getClass); - assertThat(interceptors).contains(CsrfChannelInterceptor.class); + assertThat(interceptors).contains(XorCsrfChannelInterceptor.class); } @Test @@ -238,7 +240,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests { public void messagesContextWebSocketUseSecurityContextHolderStrategy() { loadConfig(WebSocketSecurityConfig.class, SecurityContextChangedListenerConfig.class); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); - headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken()); + headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); Message message = message(headers, "/authenticated"); headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token); MessageChannel messageChannel = clientInboundChannel(); From 179428f7da20fd89f2997bd087900de39fda505c Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Thu, 26 Jan 2023 15:34:50 -0600 Subject: [PATCH 4/4] Add section for migrating WebSocket support Issue gh-12378 --- docs/modules/ROOT/nav.adoc | 1 + .../ROOT/pages/migration/servlet/exploits.adoc | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 docs/modules/ROOT/pages/migration/servlet/exploits.adoc diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index 0b1059a751..bc274ccb8c 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -5,6 +5,7 @@ * xref:migration/index.adoc[Migrating to 6.0] ** xref:migration/servlet/index.adoc[Servlet Migrations] *** xref:migration/servlet/session-management.adoc[Session Management] +*** xref:migration/servlet/exploits.adoc[Exploit Protection] *** xref:migration/servlet/authentication.adoc[Authentication] *** xref:migration/servlet/authorization.adoc[Authorization] ** xref:migration/reactive.adoc[Reactive Migrations] diff --git a/docs/modules/ROOT/pages/migration/servlet/exploits.adoc b/docs/modules/ROOT/pages/migration/servlet/exploits.adoc new file mode 100644 index 0000000000..35b49b325e --- /dev/null +++ b/docs/modules/ROOT/pages/migration/servlet/exploits.adoc @@ -0,0 +1,11 @@ += Exploit Protection Migrations + +The following steps relate to how to finish migrating exploit protection support. + +== CSRF BREACH with WebSocket support + +In Spring Security 5.8, the default `ChannelInterceptor` for making the `CsrfToken` available with xref:servlet/integrations/websocket.adoc[WebSocket Security] is `CsrfChannelInterceptor`. +`XorCsrfChannelInterceptor` was added to allow opting into CSRF BREACH support. + +In Spring Security 6, `XorCsrfChannelInterceptor` is the default `ChannelInterceptor` for making the `CsrfToken` available. +If you configured the `XorCsrfChannelInterceptor` only for the purpose of updating to 6.0, you can remove it completely.