Merge branch '6.0.x'
This commit is contained in:
commit
6abbdd3654
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2022 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -41,7 +41,7 @@ import org.springframework.security.messaging.access.intercept.AuthorizationChan
|
||||||
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
|
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
|
||||||
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
|
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
|
||||||
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
||||||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor;
|
||||||
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
|
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
|
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
|
||||||
|
@ -59,6 +59,8 @@ final class WebSocketMessageBrokerSecurityConfiguration
|
||||||
|
|
||||||
private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
|
private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
|
||||||
|
|
||||||
|
private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
|
||||||
|
|
||||||
private MessageMatcherDelegatingAuthorizationManager b;
|
private MessageMatcherDelegatingAuthorizationManager b;
|
||||||
|
|
||||||
private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
|
private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
|
||||||
|
@ -69,7 +71,7 @@ final class WebSocketMessageBrokerSecurityConfiguration
|
||||||
|
|
||||||
private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
|
private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
|
||||||
|
|
||||||
private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
|
private ChannelInterceptor csrfChannelInterceptor = new XorCsrfChannelInterceptor();
|
||||||
|
|
||||||
private AuthorizationManager<Message<?>> authorizationManager = ANY_MESSAGE_AUTHENTICATED;
|
private AuthorizationManager<Message<?>> authorizationManager = ANY_MESSAGE_AUTHENTICATED;
|
||||||
|
|
||||||
|
@ -90,6 +92,12 @@ final class WebSocketMessageBrokerSecurityConfiguration
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void configureClientInboundChannel(ChannelRegistration registration) {
|
public void configureClientInboundChannel(ChannelRegistration registration) {
|
||||||
|
ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
|
||||||
|
ChannelInterceptor.class);
|
||||||
|
if (csrfChannelInterceptor != null) {
|
||||||
|
this.csrfChannelInterceptor = csrfChannelInterceptor;
|
||||||
|
}
|
||||||
|
|
||||||
AuthorizationManager<Message<?>> manager = this.authorizationManager;
|
AuthorizationManager<Message<?>> manager = this.authorizationManager;
|
||||||
if (!this.observationRegistry.isNoop()) {
|
if (!this.observationRegistry.isNoop()) {
|
||||||
manager = new ObservationAuthorizationManager<>(this.observationRegistry, manager);
|
manager = new ObservationAuthorizationManager<>(this.observationRegistry, manager);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2019 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -60,6 +60,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
|
||||||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||||
import org.springframework.stereotype.Controller;
|
import org.springframework.stereotype.Controller;
|
||||||
import org.springframework.test.util.ReflectionTestUtils;
|
import org.springframework.test.util.ReflectionTestUtils;
|
||||||
|
@ -78,6 +79,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
|
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
||||||
|
|
||||||
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
||||||
|
|
||||||
|
@ -283,7 +285,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
||||||
|
|
||||||
private void assertHandshake(HttpServletRequest request) {
|
private void assertHandshake(HttpServletRequest request) {
|
||||||
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
|
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
|
||||||
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
|
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||||
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
|
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
|
||||||
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
|
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
|
||||||
}
|
}
|
||||||
|
@ -305,7 +307,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
||||||
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
||||||
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
||||||
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
|
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
|
||||||
request.setAttribute(CsrfToken.class.getName(), this.token);
|
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
|
||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.config.annotation.web.socket;
|
||||||
|
|
||||||
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Steve Riesenberg
|
||||||
|
*/
|
||||||
|
final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
||||||
|
|
||||||
|
private final CsrfToken csrfToken;
|
||||||
|
|
||||||
|
TestDeferredCsrfToken(CsrfToken csrfToken) {
|
||||||
|
this.csrfToken = csrfToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CsrfToken get() {
|
||||||
|
return this.csrfToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isGenerated() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2022 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -66,9 +66,10 @@ import org.springframework.security.messaging.access.intercept.AuthorizationChan
|
||||||
import org.springframework.security.messaging.access.intercept.MessageAuthorizationContext;
|
import org.springframework.security.messaging.access.intercept.MessageAuthorizationContext;
|
||||||
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
|
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
|
||||||
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
||||||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor;
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||||
import org.springframework.stereotype.Controller;
|
import org.springframework.stereotype.Controller;
|
||||||
import org.springframework.test.util.ReflectionTestUtils;
|
import org.springframework.test.util.ReflectionTestUtils;
|
||||||
|
@ -91,9 +92,12 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.assertj.core.api.Assertions.fail;
|
import static org.assertj.core.api.Assertions.fail;
|
||||||
import static org.mockito.Mockito.atLeastOnce;
|
import static org.mockito.Mockito.atLeastOnce;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
||||||
|
|
||||||
public class WebSocketMessageBrokerSecurityConfigurationTests {
|
public class WebSocketMessageBrokerSecurityConfigurationTests {
|
||||||
|
|
||||||
|
private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
|
||||||
|
|
||||||
AnnotationConfigWebApplicationContext context;
|
AnnotationConfigWebApplicationContext context;
|
||||||
|
|
||||||
Authentication messageUser;
|
Authentication messageUser;
|
||||||
|
@ -196,7 +200,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
|
||||||
MessageChannel messageChannel = clientInboundChannel();
|
MessageChannel messageChannel = clientInboundChannel();
|
||||||
Stream<Class<? extends ChannelInterceptor>> interceptors = ((AbstractMessageChannel) messageChannel)
|
Stream<Class<? extends ChannelInterceptor>> interceptors = ((AbstractMessageChannel) messageChannel)
|
||||||
.getInterceptors().stream().map(ChannelInterceptor::getClass);
|
.getInterceptors().stream().map(ChannelInterceptor::getClass);
|
||||||
assertThat(interceptors).contains(CsrfChannelInterceptor.class);
|
assertThat(interceptors).contains(XorCsrfChannelInterceptor.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -236,7 +240,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
|
||||||
public void messagesContextWebSocketUseSecurityContextHolderStrategy() {
|
public void messagesContextWebSocketUseSecurityContextHolderStrategy() {
|
||||||
loadConfig(WebSocketSecurityConfig.class, SecurityContextChangedListenerConfig.class);
|
loadConfig(WebSocketSecurityConfig.class, SecurityContextChangedListenerConfig.class);
|
||||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||||
headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken());
|
headers.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
|
||||||
Message<?> message = message(headers, "/authenticated");
|
Message<?> message = message(headers, "/authenticated");
|
||||||
headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||||
MessageChannel messageChannel = clientInboundChannel();
|
MessageChannel messageChannel = clientInboundChannel();
|
||||||
|
@ -366,7 +370,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
|
||||||
|
|
||||||
private void assertHandshake(HttpServletRequest request) {
|
private void assertHandshake(HttpServletRequest request) {
|
||||||
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
|
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
|
||||||
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
|
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||||
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
|
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
|
||||||
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
|
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
|
||||||
}
|
}
|
||||||
|
@ -388,7 +392,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
|
||||||
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
||||||
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
||||||
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
|
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
|
||||||
request.setAttribute(CsrfToken.class.getName(), this.token);
|
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
|
||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2022 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio
|
||||||
import org.springframework.security.test.context.support.WithMockUser;
|
import org.springframework.security.test.context.support.WithMockUser;
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||||
import org.springframework.stereotype.Controller;
|
import org.springframework.stereotype.Controller;
|
||||||
import org.springframework.test.context.junit.jupiter.SpringExtension;
|
import org.springframework.test.context.junit.jupiter.SpringExtension;
|
||||||
|
@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.BDDMockito.given;
|
import static org.mockito.BDDMockito.given;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
||||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests {
|
||||||
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
|
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
|
||||||
String csrfAttributeName = CsrfToken.class.getName();
|
String csrfAttributeName = CsrfToken.class.getName();
|
||||||
String customAttributeName = this.getClass().getName();
|
String customAttributeName = this.getClass().getName();
|
||||||
MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
|
MvcResult result = mvc.perform(
|
||||||
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
|
get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
|
||||||
|
.sessionAttr(customAttributeName, "attributeValue"))
|
||||||
|
.andReturn();
|
||||||
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
|
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
|
||||||
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
|
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
|
||||||
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
|
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
|
||||||
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
||||||
assertThat(handshakeValue).isEqualTo(sessionValue)
|
assertThat(handshakeValue).isEqualTo(sessionValue)
|
||||||
.withFailMessage("Explicitly listed session variables are not overridden");
|
.withFailMessage("Explicitly listed session variables are not overridden");
|
||||||
}
|
}
|
||||||
|
@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests {
|
||||||
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
|
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
|
||||||
String csrfAttributeName = CsrfToken.class.getName();
|
String csrfAttributeName = CsrfToken.class.getName();
|
||||||
String customAttributeName = this.getClass().getName();
|
String customAttributeName = this.getClass().getName();
|
||||||
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
|
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
|
||||||
|
.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
|
||||||
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
|
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
|
||||||
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
|
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
|
||||||
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
|
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
|
||||||
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
|
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
|
||||||
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
||||||
assertThat(handshakeValue).isEqualTo(sessionValue)
|
assertThat(handshakeValue).isEqualTo(sessionValue)
|
||||||
.withFailMessage("Explicitly listed session variables are not overridden");
|
.withFailMessage("Explicitly listed session variables are not overridden");
|
||||||
}
|
}
|
||||||
|
@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests {
|
||||||
return SecurityContextHolder.getContextHolderStrategy();
|
return SecurityContextHolder.getContextHolderStrategy();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
||||||
|
|
||||||
|
private final CsrfToken csrfToken;
|
||||||
|
|
||||||
|
TestDeferredCsrfToken(CsrfToken csrfToken) {
|
||||||
|
this.csrfToken = csrfToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CsrfToken get() {
|
||||||
|
return this.csrfToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isGenerated() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Controller
|
@Controller
|
||||||
static class MessageController {
|
static class MessageController {
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
* xref:migration/index.adoc[Migrating to 6.0]
|
* xref:migration/index.adoc[Migrating to 6.0]
|
||||||
** xref:migration/servlet/index.adoc[Servlet Migrations]
|
** xref:migration/servlet/index.adoc[Servlet Migrations]
|
||||||
*** xref:migration/servlet/session-management.adoc[Session Management]
|
*** xref:migration/servlet/session-management.adoc[Session Management]
|
||||||
|
*** xref:migration/servlet/exploits.adoc[Exploit Protection]
|
||||||
*** xref:migration/servlet/authentication.adoc[Authentication]
|
*** xref:migration/servlet/authentication.adoc[Authentication]
|
||||||
*** xref:migration/servlet/authorization.adoc[Authorization]
|
*** xref:migration/servlet/authorization.adoc[Authorization]
|
||||||
** xref:migration/reactive.adoc[Reactive Migrations]
|
** xref:migration/reactive.adoc[Reactive Migrations]
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
= Exploit Protection Migrations
|
||||||
|
|
||||||
|
The following steps relate to how to finish migrating exploit protection support.
|
||||||
|
|
||||||
|
== CSRF BREACH with WebSocket support
|
||||||
|
|
||||||
|
In Spring Security 5.8, the default `ChannelInterceptor` for making the `CsrfToken` available with xref:servlet/integrations/websocket.adoc[WebSocket Security] is `CsrfChannelInterceptor`.
|
||||||
|
`XorCsrfChannelInterceptor` was added to allow opting into CSRF BREACH support.
|
||||||
|
|
||||||
|
In Spring Security 6, `XorCsrfChannelInterceptor` is the default `ChannelInterceptor` for making the `CsrfToken` available.
|
||||||
|
If you configured the `XorCsrfChannelInterceptor` only for the purpose of updating to 6.0, you can remove it completely.
|
|
@ -0,0 +1,86 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.messaging.web.csrf;
|
||||||
|
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.springframework.messaging.Message;
|
||||||
|
import org.springframework.messaging.MessageChannel;
|
||||||
|
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||||
|
import org.springframework.messaging.simp.SimpMessageType;
|
||||||
|
import org.springframework.messaging.support.ChannelInterceptor;
|
||||||
|
import org.springframework.security.crypto.codec.Utf8;
|
||||||
|
import org.springframework.security.messaging.util.matcher.MessageMatcher;
|
||||||
|
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
|
||||||
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||||
|
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link ChannelInterceptor} that validates a CSRF token masked by the
|
||||||
|
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in
|
||||||
|
* the header of any {@link SimpMessageType#CONNECT} message.
|
||||||
|
*
|
||||||
|
* @author Steve Riesenberg
|
||||||
|
* @since 5.8
|
||||||
|
*/
|
||||||
|
public final class XorCsrfChannelInterceptor implements ChannelInterceptor {
|
||||||
|
|
||||||
|
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Message<?> preSend(Message<?> message, MessageChannel channel) {
|
||||||
|
if (!this.matcher.matches(message)) {
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
|
||||||
|
CsrfToken expectedToken = (sessionAttributes != null)
|
||||||
|
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
|
||||||
|
if (expectedToken == null) {
|
||||||
|
throw new MissingCsrfTokenException(null);
|
||||||
|
}
|
||||||
|
String actualToken = SimpMessageHeaderAccessor.wrap(message)
|
||||||
|
.getFirstNativeHeader(expectedToken.getHeaderName());
|
||||||
|
String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken());
|
||||||
|
boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue);
|
||||||
|
if (!csrfCheckPassed) {
|
||||||
|
throw new InvalidCsrfTokenException(expectedToken, actualToken);
|
||||||
|
}
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constant time comparison to prevent against timing attacks.
|
||||||
|
* @param expected
|
||||||
|
* @param actual
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private static boolean equalsConstantTime(String expected, String actual) {
|
||||||
|
if (expected == actual) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (expected == null || actual == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Encode after ensure that the string is not null
|
||||||
|
byte[] expectedBytes = Utf8.encode(expected);
|
||||||
|
byte[] actualBytes = Utf8.encode(actual);
|
||||||
|
return MessageDigest.isEqual(expectedBytes, actualBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.messaging.web.csrf;
|
||||||
|
|
||||||
|
import java.util.Base64;
|
||||||
|
|
||||||
|
import org.springframework.security.crypto.codec.Utf8;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copied from
|
||||||
|
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}.
|
||||||
|
*
|
||||||
|
* @see <a href=
|
||||||
|
* "https://github.com/spring-projects/spring-security/issues/12378">gh-12378</a>
|
||||||
|
*/
|
||||||
|
final class XorCsrfTokenUtils {
|
||||||
|
|
||||||
|
private XorCsrfTokenUtils() {
|
||||||
|
}
|
||||||
|
|
||||||
|
static String getTokenValue(String actualToken, String token) {
|
||||||
|
byte[] actualBytes;
|
||||||
|
try {
|
||||||
|
actualBytes = Base64.getUrlDecoder().decode(actualToken);
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
byte[] tokenBytes = Utf8.encode(token);
|
||||||
|
int tokenSize = tokenBytes.length;
|
||||||
|
if (actualBytes.length < tokenSize) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract token and random bytes
|
||||||
|
int randomBytesSize = actualBytes.length - tokenSize;
|
||||||
|
byte[] xoredCsrf = new byte[tokenSize];
|
||||||
|
byte[] randomBytes = new byte[randomBytesSize];
|
||||||
|
|
||||||
|
System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
|
||||||
|
System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
|
||||||
|
|
||||||
|
byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
|
||||||
|
return Utf8.decode(csrfBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
|
||||||
|
int len = Math.min(randomBytes.length, csrfBytes.length);
|
||||||
|
byte[] xoredCsrf = new byte[len];
|
||||||
|
System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
xoredCsrf[i] ^= randomBytes[i];
|
||||||
|
}
|
||||||
|
return xoredCsrf;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2015 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest;
|
||||||
import org.springframework.http.server.ServerHttpResponse;
|
import org.springframework.http.server.ServerHttpResponse;
|
||||||
import org.springframework.http.server.ServletServerHttpRequest;
|
import org.springframework.http.server.ServletServerHttpRequest;
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||||
import org.springframework.web.socket.WebSocketHandler;
|
import org.springframework.web.socket.WebSocketHandler;
|
||||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket
|
* Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the
|
||||||
* attributes. This is used as the expected CsrfToken when validating connection requests
|
* WebSocket attributes. This is used as the expected CsrfToken when validating connection
|
||||||
* to ensure only the same origin connects.
|
* requests to ensure only the same origin connects.
|
||||||
*
|
*
|
||||||
* @author Rob Winch
|
* @author Rob Winch
|
||||||
|
* @author Steve Riesenberg
|
||||||
* @since 4.0
|
* @since 4.0
|
||||||
*/
|
*/
|
||||||
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
|
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
|
||||||
|
@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor
|
||||||
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
|
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
|
||||||
Map<String, Object> attributes) {
|
Map<String, Object> attributes) {
|
||||||
HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
||||||
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
|
DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest
|
||||||
if (token == null) {
|
.getAttribute(DeferredCsrfToken.class.getName());
|
||||||
|
if (deferredCsrfToken == null) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
attributes.put(CsrfToken.class.getName(), token);
|
CsrfToken csrfToken = deferredCsrfToken.get();
|
||||||
|
// Ensure the values of the CsrfToken are copied into a new token so the old token
|
||||||
|
// is available for garbage collection.
|
||||||
|
// This is required because the original token could hold a reference to the
|
||||||
|
// HttpServletRequest/Response of the handshake request.
|
||||||
|
CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(),
|
||||||
|
csrfToken.getToken());
|
||||||
|
attributes.put(CsrfToken.class.getName(), resolvedCsrfToken);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2023 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.messaging.web.csrf;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import org.springframework.messaging.Message;
|
||||||
|
import org.springframework.messaging.MessageChannel;
|
||||||
|
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||||
|
import org.springframework.messaging.simp.SimpMessageType;
|
||||||
|
import org.springframework.messaging.support.MessageBuilder;
|
||||||
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||||
|
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests for {@link XorCsrfChannelInterceptor}.
|
||||||
|
*
|
||||||
|
* @author Steve Riesenberg
|
||||||
|
*/
|
||||||
|
public class XorCsrfChannelInterceptorTests {
|
||||||
|
|
||||||
|
private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
|
||||||
|
|
||||||
|
private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ==";
|
||||||
|
|
||||||
|
private CsrfToken token;
|
||||||
|
|
||||||
|
private SimpMessageHeaderAccessor messageHeaders;
|
||||||
|
|
||||||
|
private MessageChannel channel;
|
||||||
|
|
||||||
|
private XorCsrfChannelInterceptor interceptor;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
public void setup() {
|
||||||
|
this.token = new DefaultCsrfToken("header", "param", "token");
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||||
|
this.messageHeaders.setSessionAttributes(new HashMap<>());
|
||||||
|
this.channel = mock(MessageChannel.class);
|
||||||
|
this.interceptor = new XorCsrfChannelInterceptor();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenConnectWithValidTokenThenSuccess() {
|
||||||
|
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
|
||||||
|
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() {
|
||||||
|
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE);
|
||||||
|
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||||
|
// @formatter:off
|
||||||
|
assertThatExceptionOfType(InvalidCsrfTokenException.class)
|
||||||
|
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() {
|
||||||
|
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||||
|
// @formatter:off
|
||||||
|
assertThatExceptionOfType(InvalidCsrfTokenException.class)
|
||||||
|
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() {
|
||||||
|
// @formatter:off
|
||||||
|
assertThatExceptionOfType(MissingCsrfTokenException.class)
|
||||||
|
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() {
|
||||||
|
this.messageHeaders.setSessionAttributes(null);
|
||||||
|
// @formatter:off
|
||||||
|
assertThatExceptionOfType(MissingCsrfTokenException.class)
|
||||||
|
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenAckThenIgnores() {
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenDisconnectThenIgnores() {
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenHeartbeatThenIgnores() {
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenMessageThenIgnores() {
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenOtherThenIgnores() {
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preSendWhenUnsubscribeThenIgnores() {
|
||||||
|
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
|
||||||
|
this.interceptor.preSend(message(), this.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Message<String> message() {
|
||||||
|
return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2016 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.security.web.csrf.CsrfToken;
|
import org.springframework.security.web.csrf.CsrfToken;
|
||||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||||
|
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||||
import org.springframework.web.socket.WebSocketHandler;
|
import org.springframework.web.socket.WebSocketHandler;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests {
|
||||||
@Test
|
@Test
|
||||||
public void beforeHandshake() throws Exception {
|
public void beforeHandshake() throws Exception {
|
||||||
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
|
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
|
||||||
this.httpRequest.setAttribute(CsrfToken.class.getName(), token);
|
this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token));
|
||||||
this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes);
|
this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes);
|
||||||
assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName());
|
assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName());
|
||||||
assertThat(this.attributes.values()).containsOnly(token);
|
CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName());
|
||||||
|
assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName());
|
||||||
|
assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName());
|
||||||
|
assertThat(csrfToken.getToken()).isEqualTo(token.getToken());
|
||||||
|
// Ensure the values of the CsrfToken are copied into a new token so the old token
|
||||||
|
// is available for garbage collection.
|
||||||
|
// This is required because the original token could hold a reference to the
|
||||||
|
// HttpServletRequest/Response of the handshake request.
|
||||||
|
assertThat(csrfToken).isNotSameAs(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
||||||
|
|
||||||
|
private final CsrfToken csrfToken;
|
||||||
|
|
||||||
|
private TestDeferredCsrfToken(CsrfToken csrfToken) {
|
||||||
|
this.csrfToken = csrfToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CsrfToken get() {
|
||||||
|
return this.csrfToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isGenerated() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,6 +107,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
||||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
|
DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
|
||||||
|
request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken);
|
||||||
this.requestHandler.handle(request, response, deferredCsrfToken::get);
|
this.requestHandler.handle(request, response, deferredCsrfToken::get);
|
||||||
if (!this.requireCsrfProtectionMatcher.matches(request)) {
|
if (!this.requireCsrfProtectionMatcher.matches(request)) {
|
||||||
if (this.logger.isTraceEnabled()) {
|
if (this.logger.isTraceEnabled()) {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright 2002-2022 the original author or authors.
|
* Copyright 2002-2023 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -127,11 +127,12 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||||
verifyNoMoreInteractions(this.filterChain);
|
verifyNoMoreInteractions(this.filterChain);
|
||||||
}
|
}
|
||||||
|
@ -139,12 +140,13 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||||
verifyNoMoreInteractions(this.filterChain);
|
verifyNoMoreInteractions(this.filterChain);
|
||||||
}
|
}
|
||||||
|
@ -152,12 +154,13 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||||
verifyNoMoreInteractions(this.filterChain);
|
verifyNoMoreInteractions(this.filterChain);
|
||||||
}
|
}
|
||||||
|
@ -166,8 +169,8 @@ public class CsrfFilterTests {
|
||||||
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
||||||
handler.handle(this.request, this.response, () -> this.token);
|
handler.handle(this.request, this.response, () -> this.token);
|
||||||
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
||||||
|
@ -176,6 +179,7 @@ public class CsrfFilterTests {
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||||
verifyNoMoreInteractions(this.filterChain);
|
verifyNoMoreInteractions(this.filterChain);
|
||||||
}
|
}
|
||||||
|
@ -183,11 +187,12 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
verifyNoMoreInteractions(this.deniedHandler);
|
verifyNoMoreInteractions(this.deniedHandler);
|
||||||
}
|
}
|
||||||
|
@ -195,11 +200,12 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, true));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
verifyNoMoreInteractions(this.deniedHandler);
|
verifyNoMoreInteractions(this.deniedHandler);
|
||||||
}
|
}
|
||||||
|
@ -207,8 +213,8 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
||||||
handler.handle(this.request, this.response, () -> this.token);
|
handler.handle(this.request, this.response, () -> this.token);
|
||||||
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
||||||
|
@ -216,6 +222,7 @@ public class CsrfFilterTests {
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
verifyNoMoreInteractions(this.deniedHandler);
|
verifyNoMoreInteractions(this.deniedHandler);
|
||||||
}
|
}
|
||||||
|
@ -224,8 +231,8 @@ public class CsrfFilterTests {
|
||||||
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
||||||
handler.handle(this.request, this.response, () -> this.token);
|
handler.handle(this.request, this.response, () -> this.token);
|
||||||
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
||||||
|
@ -234,6 +241,7 @@ public class CsrfFilterTests {
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
verifyNoMoreInteractions(this.deniedHandler);
|
verifyNoMoreInteractions(this.deniedHandler);
|
||||||
}
|
}
|
||||||
|
@ -241,8 +249,8 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
||||||
handler.handle(this.request, this.response, () -> this.token);
|
handler.handle(this.request, this.response, () -> this.token);
|
||||||
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
||||||
|
@ -250,6 +258,7 @@ public class CsrfFilterTests {
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
verifyNoMoreInteractions(this.deniedHandler);
|
verifyNoMoreInteractions(this.deniedHandler);
|
||||||
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
|
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
|
||||||
|
@ -259,8 +268,8 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, true));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler();
|
||||||
handler.handle(this.request, this.response, () -> this.token);
|
handler.handle(this.request, this.response, () -> this.token);
|
||||||
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
|
||||||
|
@ -268,6 +277,7 @@ public class CsrfFilterTests {
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
// LazyCsrfTokenRepository requires the response as an attribute
|
// LazyCsrfTokenRepository requires the response as an attribute
|
||||||
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
|
@ -332,11 +342,12 @@ public class CsrfFilterTests {
|
||||||
this.filter = new CsrfFilter(this.tokenRepository);
|
this.filter = new CsrfFilter(this.tokenRepository);
|
||||||
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull();
|
||||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
||||||
verifyNoMoreInteractions(this.filterChain);
|
verifyNoMoreInteractions(this.filterChain);
|
||||||
}
|
}
|
||||||
|
@ -360,22 +371,24 @@ public class CsrfFilterTests {
|
||||||
given(token.getToken()).willReturn(null);
|
given(token.getToken()).willReturn(null);
|
||||||
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
||||||
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
||||||
this.filter = createCsrfFilter(this.tokenRepository);
|
this.filter = createCsrfFilter(this.tokenRepository);
|
||||||
this.filter.setRequestHandler(requestHandler);
|
this.filter.setRequestHandler(requestHandler);
|
||||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
|
verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
|
||||||
verify(requestHandler).handle(eq(this.request), eq(this.response), any());
|
verify(requestHandler).handle(eq(this.request), eq(this.response), any());
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
|
@ -384,11 +397,12 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
|
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||||
assertThat(this.request.getAttribute("_csrf")).isNotNull();
|
assertThat(this.request.getAttribute("_csrf")).isNotNull();
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.filterChain).doFilter(this.request, this.response);
|
verify(this.filterChain).doFilter(this.request, this.response);
|
||||||
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
||||||
|
|
||||||
|
@ -407,10 +421,11 @@ public class CsrfFilterTests {
|
||||||
@Test
|
@Test
|
||||||
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
|
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
|
||||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
|
||||||
verifyNoMoreInteractions(this.filterChain);
|
verifyNoMoreInteractions(this.filterChain);
|
||||||
}
|
}
|
||||||
|
@ -435,10 +450,11 @@ public class CsrfFilterTests {
|
||||||
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
||||||
filter.setRequestHandler(requestHandler);
|
filter.setRequestHandler(requestHandler);
|
||||||
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
||||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true);
|
||||||
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
|
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||||
|
|
||||||
filter.doFilter(this.request, this.response, this.filterChain);
|
filter.doFilter(this.request, this.response, this.filterChain);
|
||||||
|
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||||
|
|
||||||
verifyNoInteractions(expectedCsrfToken);
|
verifyNoInteractions(expectedCsrfToken);
|
||||||
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
||||||
|
|
Loading…
Reference in New Issue