Add XorCsrfChannelInterceptor

Issue gh-12378
This commit is contained in:
Steve Riesenberg 2023-01-09 15:58:28 -06:00
parent d42405de42
commit c306df9b46
No known key found for this signature in database
GPG Key ID: 5F311AB48A55D521
12 changed files with 501 additions and 58 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration
private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
private MessageMatcherDelegatingAuthorizationManager b;
private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration
private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor(
ANY_MESSAGE_AUTHENTICATED);
@ -86,6 +88,12 @@ final class WebSocketMessageBrokerSecurityConfiguration
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
ChannelInterceptor.class);
if (csrfChannelInterceptor != null) {
this.csrfChannelInterceptor = csrfChannelInterceptor;
}
this.authorizationChannelInterceptor
.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
@ -79,6 +80,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
@ -284,7 +286,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
}
@ -306,7 +308,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), this.token);
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
return request;
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.socket;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
/**
* @author Steve Riesenberg
*/
final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -70,6 +70,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
@ -92,6 +93,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
public class WebSocketMessageBrokerSecurityConfigurationTests {
@ -367,7 +369,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
}
@ -389,7 +391,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), this.token);
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
return request;
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension;
@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
/**
@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests {
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
String csrfAttributeName = CsrfToken.class.getName();
String customAttributeName = this.getClass().getName();
MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
MvcResult result = mvc.perform(
get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
.sessionAttr(customAttributeName, "attributeValue"))
.andReturn();
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThat(handshakeValue).isEqualTo(sessionValue)
.withFailMessage("Explicitly listed session variables are not overridden");
}
@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests {
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
String csrfAttributeName = CsrfToken.class.getName();
String customAttributeName = this.getClass().getName();
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThat(handshakeValue).isEqualTo(sessionValue)
.withFailMessage("Explicitly listed session variables are not overridden");
}
@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests {
return SecurityContextHolder.getContextHolderStrategy();
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}
@Controller
static class MessageController {

View File

@ -0,0 +1,86 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.security.MessageDigest;
import java.util.Map;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.messaging.util.matcher.MessageMatcher;
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
/**
* {@link ChannelInterceptor} that validates a CSRF token masked by the
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in
* the header of any {@link SimpMessageType#CONNECT} message.
*
* @author Steve Riesenberg
* @since 5.8
*/
public final class XorCsrfChannelInterceptor implements ChannelInterceptor {
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
if (!this.matcher.matches(message)) {
return message;
}
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
CsrfToken expectedToken = (sessionAttributes != null)
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
if (expectedToken == null) {
throw new MissingCsrfTokenException(null);
}
String actualToken = SimpMessageHeaderAccessor.wrap(message)
.getFirstNativeHeader(expectedToken.getHeaderName());
String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken());
boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue);
if (!csrfCheckPassed) {
throw new InvalidCsrfTokenException(expectedToken, actualToken);
}
return message;
}
/**
* Constant time comparison to prevent against timing attacks.
* @param expected
* @param actual
* @return
*/
private static boolean equalsConstantTime(String expected, String actual) {
if (expected == actual) {
return true;
}
if (expected == null || actual == null) {
return false;
}
// Encode after ensure that the string is not null
byte[] expectedBytes = Utf8.encode(expected);
byte[] actualBytes = Utf8.encode(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.util.Base64;
import org.springframework.security.crypto.codec.Utf8;
/**
* Copied from
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}.
*
* @see <a href=
* "https://github.com/spring-projects/spring-security/issues/12378">gh-12378</a>
*/
final class XorCsrfTokenUtils {
private XorCsrfTokenUtils() {
}
static String getTokenValue(String actualToken, String token) {
byte[] actualBytes;
try {
actualBytes = Base64.getUrlDecoder().decode(actualToken);
}
catch (Exception ex) {
return null;
}
byte[] tokenBytes = Utf8.encode(token);
int tokenSize = tokenBytes.length;
if (actualBytes.length < tokenSize) {
return null;
}
// extract token and random bytes
int randomBytesSize = actualBytes.length - tokenSize;
byte[] xoredCsrf = new byte[tokenSize];
byte[] randomBytes = new byte[randomBytesSize];
System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
return Utf8.decode(csrfBytes);
}
private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
int len = Math.min(randomBytes.length, csrfBytes.length);
byte[] xoredCsrf = new byte[len];
System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
for (int i = 0; i < len; i++) {
xoredCsrf[i] ^= randomBytes[i];
}
return xoredCsrf;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket
* attributes. This is used as the expected CsrfToken when validating connection requests
* to ensure only the same origin connects.
* Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the
* WebSocket attributes. This is used as the expected CsrfToken when validating connection
* requests to ensure only the same origin connects.
*
* @author Rob Winch
* @author Steve Riesenberg
* @since 4.0
*/
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
if (token == null) {
DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest
.getAttribute(DeferredCsrfToken.class.getName());
if (deferredCsrfToken == null) {
return true;
}
attributes.put(CsrfToken.class.getName(), token);
CsrfToken csrfToken = deferredCsrfToken.get();
// Ensure the values of the CsrfToken are copied into a new token so the old token
// is available for garbage collection.
// This is required because the original token could hold a reference to the
// HttpServletRequest/Response of the handshake request.
CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(),
csrfToken.getToken());
attributes.put(CsrfToken.class.getName(), resolvedCsrfToken);
return true;
}

View File

@ -0,0 +1,148 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.util.HashMap;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link XorCsrfChannelInterceptor}.
*
* @author Steve Riesenberg
*/
public class XorCsrfChannelInterceptorTests {
private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ==";
private CsrfToken token;
private SimpMessageHeaderAccessor messageHeaders;
private MessageChannel channel;
private XorCsrfChannelInterceptor interceptor;
@BeforeEach
public void setup() {
this.token = new DefaultCsrfToken("header", "param", "token");
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
this.messageHeaders.setSessionAttributes(new HashMap<>());
this.channel = mock(MessageChannel.class);
this.interceptor = new XorCsrfChannelInterceptor();
}
@Test
public void preSendWhenConnectWithValidTokenThenSuccess() {
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() {
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE);
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
// @formatter:off
assertThatExceptionOfType(InvalidCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() {
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
// @formatter:off
assertThatExceptionOfType(InvalidCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() {
// @formatter:off
assertThatExceptionOfType(MissingCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() {
this.messageHeaders.setSessionAttributes(null);
// @formatter:off
assertThatExceptionOfType(MissingCsrfTokenException.class)
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
// @formatter:on
}
@Test
public void preSendWhenAckThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenDisconnectThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenHeartbeatThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenMessageThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenOtherThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
this.interceptor.preSend(message(), this.channel);
}
@Test
public void preSendWhenUnsubscribeThenIgnores() {
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
this.interceptor.preSend(message(), this.channel);
}
private Message<String> message() {
return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.web.socket.WebSocketHandler;
import static org.assertj.core.api.Assertions.assertThat;
@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests {
@Test
public void beforeHandshake() throws Exception {
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
this.httpRequest.setAttribute(CsrfToken.class.getName(), token);
this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token));
this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes);
assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName());
assertThat(this.attributes.values()).containsOnly(token);
CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName());
assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName());
assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName());
assertThat(csrfToken.getToken()).isEqualTo(token.getToken());
// Ensure the values of the CsrfToken are copied into a new token so the old token
// is available for garbage collection.
// This is required because the original token could hold a reference to the
// HttpServletRequest/Response of the handshake request.
assertThat(csrfToken).isNotSameAs(token);
}
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
private final CsrfToken csrfToken;
private TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}
@Override
public CsrfToken get() {
return this.csrfToken;
}
@Override
public boolean isGenerated() {
return false;
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -108,6 +108,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken);
this.requestHandler.handle(request, response, deferredCsrfToken::get);
if (!this.requireCsrfProtectionMatcher.matches(request)) {
if (this.logger.isTraceEnabled()) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -126,11 +126,12 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -138,12 +139,13 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -151,12 +153,13 @@ public class CsrfFilterTests {
@Test
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -165,13 +168,14 @@ public class CsrfFilterTests {
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -179,11 +183,12 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -191,11 +196,12 @@ public class CsrfFilterTests {
@Test
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, true));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -203,12 +209,13 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -217,13 +224,14 @@ public class CsrfFilterTests {
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
}
@ -231,12 +239,13 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
verifyNoMoreInteractions(this.deniedHandler);
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
@ -246,12 +255,13 @@ public class CsrfFilterTests {
@Test
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, true));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
// LazyCsrfTokenRepository requires the response as an attribute
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
verify(this.filterChain).doFilter(this.request, this.response);
@ -316,11 +326,12 @@ public class CsrfFilterTests {
this.filter = new CsrfFilter(this.tokenRepository);
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
verifyNoMoreInteractions(this.filterChain);
}
@ -344,22 +355,24 @@ public class CsrfFilterTests {
given(token.getToken()).willReturn(null);
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
given(token.getParameterName()).willReturn(this.token.getParameterName());
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
given(this.requestMatcher.matches(this.request)).willReturn(true);
filter.doFilterInternal(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
}
@Test
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
this.filter = createCsrfFilter(this.tokenRepository);
this.filter.setRequestHandler(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
verify(requestHandler).handle(eq(this.request), eq(this.response), any());
verify(this.filterChain).doFilter(this.request, this.response);
@ -368,14 +381,15 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
requestHandler.setCsrfRequestAttributeName(this.token.getParameterName());
this.filter.setRequestHandler(requestHandler);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.filterChain).doFilter(this.request, this.response);
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
@ -394,12 +408,13 @@ public class CsrfFilterTests {
@Test
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
given(this.requestMatcher.matches(this.request)).willReturn(true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(this.token, false));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
this.filter.setRequestHandler(requestHandler);
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
verifyNoMoreInteractions(this.filterChain);
}
@ -424,10 +439,11 @@ public class CsrfFilterTests {
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
filter.setRequestHandler(requestHandler);
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true);
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
verifyNoInteractions(expectedCsrfToken);
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);