parent
d42405de42
commit
c306df9b46
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration
|
|||
|
||||
private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
|
||||
|
||||
private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
|
||||
|
||||
private MessageMatcherDelegatingAuthorizationManager b;
|
||||
|
||||
private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
|
||||
|
@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration
|
|||
|
||||
private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
|
||||
|
||||
private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
|
||||
private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
|
||||
|
||||
private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor(
|
||||
ANY_MESSAGE_AUTHENTICATED);
|
||||
|
@ -86,6 +88,12 @@ final class WebSocketMessageBrokerSecurityConfiguration
|
|||
|
||||
@Override
|
||||
public void configureClientInboundChannel(ChannelRegistration registration) {
|
||||
ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
|
||||
ChannelInterceptor.class);
|
||||
if (csrfChannelInterceptor != null) {
|
||||
this.csrfChannelInterceptor = csrfChannelInterceptor;
|
||||
}
|
||||
|
||||
this.authorizationChannelInterceptor
|
||||
.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
|
||||
this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -61,6 +61,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
|
|||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
|
@ -79,6 +80,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
|
|||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
||||
|
||||
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
||||
|
||||
|
@ -284,7 +286,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|||
|
||||
private void assertHandshake(HttpServletRequest request) {
|
||||
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
|
||||
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
|
||||
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
|
||||
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
|
||||
}
|
||||
|
@ -306,7 +308,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|||
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
||||
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
||||
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
|
||||
request.setAttribute(CsrfToken.class.getName(), this.token);
|
||||
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
|
||||
return request;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.config.annotation.web.socket;
|
||||
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||
|
||||
/**
|
||||
* @author Steve Riesenberg
|
||||
*/
|
||||
final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
||||
|
||||
private final CsrfToken csrfToken;
|
||||
|
||||
TestDeferredCsrfToken(CsrfToken csrfToken) {
|
||||
this.csrfToken = csrfToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CsrfToken get() {
|
||||
return this.csrfToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isGenerated() {
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -70,6 +70,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
|
|||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
|
@ -92,6 +93,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|||
import static org.assertj.core.api.Assertions.fail;
|
||||
import static org.mockito.Mockito.atLeastOnce;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
||||
|
||||
public class WebSocketMessageBrokerSecurityConfigurationTests {
|
||||
|
||||
|
@ -367,7 +369,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
|
|||
|
||||
private void assertHandshake(HttpServletRequest request) {
|
||||
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
|
||||
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
|
||||
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
|
||||
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
|
||||
}
|
||||
|
@ -389,7 +391,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
|
|||
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
||||
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
||||
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
|
||||
request.setAttribute(CsrfToken.class.getName(), this.token);
|
||||
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
|
||||
return request;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio
|
|||
import org.springframework.security.test.context.support.WithMockUser;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.test.context.junit.jupiter.SpringExtension;
|
||||
|
@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
|
||||
/**
|
||||
|
@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests {
|
|||
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
|
||||
String csrfAttributeName = CsrfToken.class.getName();
|
||||
String customAttributeName = this.getClass().getName();
|
||||
MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
|
||||
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
|
||||
MvcResult result = mvc.perform(
|
||||
get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
|
||||
.sessionAttr(customAttributeName, "attributeValue"))
|
||||
.andReturn();
|
||||
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
|
||||
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
|
||||
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
|
||||
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
||||
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
||||
assertThat(handshakeValue).isEqualTo(sessionValue)
|
||||
.withFailMessage("Explicitly listed session variables are not overridden");
|
||||
}
|
||||
|
@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests {
|
|||
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
|
||||
String csrfAttributeName = CsrfToken.class.getName();
|
||||
String customAttributeName = this.getClass().getName();
|
||||
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
|
||||
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
|
||||
.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
|
||||
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
|
||||
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
|
||||
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
|
||||
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
|
||||
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
||||
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
|
||||
assertThat(handshakeValue).isEqualTo(sessionValue)
|
||||
.withFailMessage("Explicitly listed session variables are not overridden");
|
||||
}
|
||||
|
@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests {
|
|||
return SecurityContextHolder.getContextHolderStrategy();
|
||||
}
|
||||
|
||||
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
||||
|
||||
private final CsrfToken csrfToken;
|
||||
|
||||
TestDeferredCsrfToken(CsrfToken csrfToken) {
|
||||
this.csrfToken = csrfToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CsrfToken get() {
|
||||
return this.csrfToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isGenerated() {
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Controller
|
||||
static class MessageController {
|
||||
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.messaging.web.csrf;
|
||||
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||
import org.springframework.messaging.simp.SimpMessageType;
|
||||
import org.springframework.messaging.support.ChannelInterceptor;
|
||||
import org.springframework.security.crypto.codec.Utf8;
|
||||
import org.springframework.security.messaging.util.matcher.MessageMatcher;
|
||||
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
|
||||
/**
|
||||
* {@link ChannelInterceptor} that validates a CSRF token masked by the
|
||||
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in
|
||||
* the header of any {@link SimpMessageType#CONNECT} message.
|
||||
*
|
||||
* @author Steve Riesenberg
|
||||
* @since 5.8
|
||||
*/
|
||||
public final class XorCsrfChannelInterceptor implements ChannelInterceptor {
|
||||
|
||||
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
|
||||
|
||||
@Override
|
||||
public Message<?> preSend(Message<?> message, MessageChannel channel) {
|
||||
if (!this.matcher.matches(message)) {
|
||||
return message;
|
||||
}
|
||||
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
|
||||
CsrfToken expectedToken = (sessionAttributes != null)
|
||||
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
|
||||
if (expectedToken == null) {
|
||||
throw new MissingCsrfTokenException(null);
|
||||
}
|
||||
String actualToken = SimpMessageHeaderAccessor.wrap(message)
|
||||
.getFirstNativeHeader(expectedToken.getHeaderName());
|
||||
String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken());
|
||||
boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue);
|
||||
if (!csrfCheckPassed) {
|
||||
throw new InvalidCsrfTokenException(expectedToken, actualToken);
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
/**
|
||||
* Constant time comparison to prevent against timing attacks.
|
||||
* @param expected
|
||||
* @param actual
|
||||
* @return
|
||||
*/
|
||||
private static boolean equalsConstantTime(String expected, String actual) {
|
||||
if (expected == actual) {
|
||||
return true;
|
||||
}
|
||||
if (expected == null || actual == null) {
|
||||
return false;
|
||||
}
|
||||
// Encode after ensure that the string is not null
|
||||
byte[] expectedBytes = Utf8.encode(expected);
|
||||
byte[] actualBytes = Utf8.encode(actual);
|
||||
return MessageDigest.isEqual(expectedBytes, actualBytes);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.messaging.web.csrf;
|
||||
|
||||
import java.util.Base64;
|
||||
|
||||
import org.springframework.security.crypto.codec.Utf8;
|
||||
|
||||
/**
|
||||
* Copied from
|
||||
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}.
|
||||
*
|
||||
* @see <a href=
|
||||
* "https://github.com/spring-projects/spring-security/issues/12378">gh-12378</a>
|
||||
*/
|
||||
final class XorCsrfTokenUtils {
|
||||
|
||||
private XorCsrfTokenUtils() {
|
||||
}
|
||||
|
||||
static String getTokenValue(String actualToken, String token) {
|
||||
byte[] actualBytes;
|
||||
try {
|
||||
actualBytes = Base64.getUrlDecoder().decode(actualToken);
|
||||
}
|
||||
catch (Exception ex) {
|
||||
return null;
|
||||
}
|
||||
|
||||
byte[] tokenBytes = Utf8.encode(token);
|
||||
int tokenSize = tokenBytes.length;
|
||||
if (actualBytes.length < tokenSize) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// extract token and random bytes
|
||||
int randomBytesSize = actualBytes.length - tokenSize;
|
||||
byte[] xoredCsrf = new byte[tokenSize];
|
||||
byte[] randomBytes = new byte[randomBytesSize];
|
||||
|
||||
System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
|
||||
System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
|
||||
|
||||
byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
|
||||
return Utf8.decode(csrfBytes);
|
||||
}
|
||||
|
||||
private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
|
||||
int len = Math.min(randomBytes.length, csrfBytes.length);
|
||||
byte[] xoredCsrf = new byte[len];
|
||||
System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
|
||||
for (int i = 0; i < len; i++) {
|
||||
xoredCsrf[i] ^= randomBytes[i];
|
||||
}
|
||||
return xoredCsrf;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest;
|
|||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
/**
|
||||
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket
|
||||
* attributes. This is used as the expected CsrfToken when validating connection requests
|
||||
* to ensure only the same origin connects.
|
||||
* Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the
|
||||
* WebSocket attributes. This is used as the expected CsrfToken when validating connection
|
||||
* requests to ensure only the same origin connects.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @author Steve Riesenberg
|
||||
* @since 4.0
|
||||
*/
|
||||
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
|
||||
|
@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor
|
|||
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
|
||||
Map<String, Object> attributes) {
|
||||
HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
||||
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
|
||||
if (token == null) {
|
||||
DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest
|
||||
.getAttribute(DeferredCsrfToken.class.getName());
|
||||
if (deferredCsrfToken == null) {
|
||||
return true;
|
||||
}
|
||||
attributes.put(CsrfToken.class.getName(), token);
|
||||
CsrfToken csrfToken = deferredCsrfToken.get();
|
||||
// Ensure the values of the CsrfToken are copied into a new token so the old token
|
||||
// is available for garbage collection.
|
||||
// This is required because the original token could hold a reference to the
|
||||
// HttpServletRequest/Response of the handshake request.
|
||||
CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(),
|
||||
csrfToken.getToken());
|
||||
attributes.put(CsrfToken.class.getName(), resolvedCsrfToken);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
/*
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.security.messaging.web.csrf;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||
import org.springframework.messaging.simp.SimpMessageType;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Tests for {@link XorCsrfChannelInterceptor}.
|
||||
*
|
||||
* @author Steve Riesenberg
|
||||
*/
|
||||
public class XorCsrfChannelInterceptorTests {
|
||||
|
||||
private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
|
||||
|
||||
private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ==";
|
||||
|
||||
private CsrfToken token;
|
||||
|
||||
private SimpMessageHeaderAccessor messageHeaders;
|
||||
|
||||
private MessageChannel channel;
|
||||
|
||||
private XorCsrfChannelInterceptor interceptor;
|
||||
|
||||
@BeforeEach
|
||||
public void setup() {
|
||||
this.token = new DefaultCsrfToken("header", "param", "token");
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
this.messageHeaders.setSessionAttributes(new HashMap<>());
|
||||
this.channel = mock(MessageChannel.class);
|
||||
this.interceptor = new XorCsrfChannelInterceptor();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenConnectWithValidTokenThenSuccess() {
|
||||
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
|
||||
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() {
|
||||
this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE);
|
||||
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||
// @formatter:off
|
||||
assertThatExceptionOfType(InvalidCsrfTokenException.class)
|
||||
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() {
|
||||
this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
|
||||
// @formatter:off
|
||||
assertThatExceptionOfType(InvalidCsrfTokenException.class)
|
||||
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() {
|
||||
// @formatter:off
|
||||
assertThatExceptionOfType(MissingCsrfTokenException.class)
|
||||
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() {
|
||||
this.messageHeaders.setSessionAttributes(null);
|
||||
// @formatter:off
|
||||
assertThatExceptionOfType(MissingCsrfTokenException.class)
|
||||
.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenAckThenIgnores() {
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenDisconnectThenIgnores() {
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenHeartbeatThenIgnores() {
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenMessageThenIgnores() {
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenOtherThenIgnores() {
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendWhenUnsubscribeThenIgnores() {
|
||||
this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
|
||||
this.interceptor.preSend(message(), this.channel);
|
||||
}
|
||||
|
||||
private Message<String> message() {
|
||||
return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2016 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
|
|||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.DeferredCsrfToken;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests {
|
|||
@Test
|
||||
public void beforeHandshake() throws Exception {
|
||||
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
|
||||
this.httpRequest.setAttribute(CsrfToken.class.getName(), token);
|
||||
this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token));
|
||||
this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes);
|
||||
assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName());
|
||||
assertThat(this.attributes.values()).containsOnly(token);
|
||||
CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName());
|
||||
assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName());
|
||||
assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName());
|
||||
assertThat(csrfToken.getToken()).isEqualTo(token.getToken());
|
||||
// Ensure the values of the CsrfToken are copied into a new token so the old token
|
||||
// is available for garbage collection.
|
||||
// This is required because the original token could hold a reference to the
|
||||
// HttpServletRequest/Response of the handshake request.
|
||||
assertThat(csrfToken).isNotSameAs(token);
|
||||
}
|
||||
|
||||
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
||||
|
||||
private final CsrfToken csrfToken;
|
||||
|
||||
private TestDeferredCsrfToken(CsrfToken csrfToken) {
|
||||
this.csrfToken = csrfToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CsrfToken get() {
|
||||
return this.csrfToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isGenerated() {
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -108,6 +108,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
|
|||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||
throws ServletException, IOException {
|
||||
DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
|
||||
request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken);
|
||||
this.requestHandler.handle(request, response, deferredCsrfToken::get);
|
||||
if (!this.requireCsrfProtectionMatcher.matches(request)) {
|
||||
if (this.logger.isTraceEnabled()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -126,11 +126,12 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||
verifyNoMoreInteractions(this.filterChain);
|
||||
}
|
||||
|
@ -138,12 +139,13 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||
verifyNoMoreInteractions(this.filterChain);
|
||||
}
|
||||
|
@ -151,12 +153,13 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||
verifyNoMoreInteractions(this.filterChain);
|
||||
}
|
||||
|
@ -165,13 +168,14 @@ public class CsrfFilterTests {
|
|||
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
||||
throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
||||
verifyNoMoreInteractions(this.filterChain);
|
||||
}
|
||||
|
@ -179,11 +183,12 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
verifyNoMoreInteractions(this.deniedHandler);
|
||||
}
|
||||
|
@ -191,11 +196,12 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, true));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
verifyNoMoreInteractions(this.deniedHandler);
|
||||
}
|
||||
|
@ -203,12 +209,13 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
verifyNoMoreInteractions(this.deniedHandler);
|
||||
}
|
||||
|
@ -217,13 +224,14 @@ public class CsrfFilterTests {
|
|||
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
||||
throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
||||
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
verifyNoMoreInteractions(this.deniedHandler);
|
||||
}
|
||||
|
@ -231,12 +239,13 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
verifyNoMoreInteractions(this.deniedHandler);
|
||||
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
|
||||
|
@ -246,12 +255,13 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, true));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
// LazyCsrfTokenRepository requires the response as an attribute
|
||||
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
|
@ -316,11 +326,12 @@ public class CsrfFilterTests {
|
|||
this.filter = new CsrfFilter(this.tokenRepository);
|
||||
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
||||
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
||||
verifyNoMoreInteractions(this.filterChain);
|
||||
}
|
||||
|
@ -344,22 +355,24 @@ public class CsrfFilterTests {
|
|||
given(token.getToken()).willReturn(null);
|
||||
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
||||
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
||||
this.filter = createCsrfFilter(this.tokenRepository);
|
||||
this.filter.setRequestHandler(requestHandler);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
|
||||
verify(requestHandler).handle(eq(this.request), eq(this.response), any());
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
|
@ -368,14 +381,15 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
|
||||
requestHandler.setCsrfRequestAttributeName(this.token.getParameterName());
|
||||
this.filter.setRequestHandler(requestHandler);
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
|
||||
assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.filterChain).doFilter(this.request, this.response);
|
||||
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
||||
|
||||
|
@ -394,12 +408,13 @@ public class CsrfFilterTests {
|
|||
@Test
|
||||
public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
|
||||
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
|
||||
this.filter.setRequestHandler(requestHandler);
|
||||
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
||||
this.filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
|
||||
verifyNoMoreInteractions(this.filterChain);
|
||||
}
|
||||
|
@ -424,10 +439,11 @@ public class CsrfFilterTests {
|
|||
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
||||
filter.setRequestHandler(requestHandler);
|
||||
CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
||||
.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
|
||||
DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true);
|
||||
given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
|
||||
|
||||
filter.doFilter(this.request, this.response, this.filterChain);
|
||||
assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
|
||||
|
||||
verifyNoInteractions(expectedCsrfToken);
|
||||
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
||||
|
|
Loading…
Reference in New Issue