SEC-2830: Provide Same Origin support for SockJS
This commit is contained in:
parent
a27c33754c
commit
6a8475adbb
|
@ -15,9 +15,14 @@
|
|||
*/
|
||||
package org.springframework.security.config.annotation.web.configurers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.springframework.security.access.AccessDeniedException;
|
||||
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
|
||||
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
import org.springframework.security.web.access.AccessDeniedHandler;
|
||||
|
@ -31,6 +36,9 @@ import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
|
|||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler;
|
||||
import org.springframework.security.web.session.InvalidSessionStrategy;
|
||||
import org.springframework.security.web.util.matcher.AndRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
|
@ -66,7 +74,8 @@ import org.springframework.util.Assert;
|
|||
*/
|
||||
public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends AbstractHttpConfigurer<CsrfConfigurer<H>,H> {
|
||||
private CsrfTokenRepository csrfTokenRepository = new HttpSessionCsrfTokenRepository();
|
||||
private RequestMatcher requireCsrfProtectionMatcher;
|
||||
private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER;
|
||||
private List<RequestMatcher> ignoredCsrfProtectionMatchers = new ArrayList<RequestMatcher>();
|
||||
|
||||
/**
|
||||
* Creates a new instance
|
||||
|
@ -102,10 +111,38 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends Abst
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection even if they match the {@link #requireCsrfProtectionMatcher(RequestMatcher)}.
|
||||
* </p>
|
||||
*
|
||||
* <p>
|
||||
* The following will ensure CSRF protection ignores:
|
||||
* </p>
|
||||
* <ul>
|
||||
* <li>Any GET, HEAD, TRACE, OPTIONS (this is the default)</li>
|
||||
* <li>We also explicitly state to ignore any request that starts with "/sockjs/"</li>
|
||||
* </ul>
|
||||
*
|
||||
* <pre>
|
||||
* http
|
||||
* .csrf()
|
||||
* .ignoringAntMatchers("/sockjs/**")
|
||||
* .and()
|
||||
* ...
|
||||
* </pre>
|
||||
*
|
||||
* @since 4.0
|
||||
*/
|
||||
public CsrfConfigurer<H> ignoringAntMatchers(String... antPatterns) {
|
||||
return new IgnoreCsrfProtectionRegistry().antMatchers(antPatterns).and();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public void configure(H http) throws Exception {
|
||||
CsrfFilter filter = new CsrfFilter(csrfTokenRepository);
|
||||
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
|
||||
if(requireCsrfProtectionMatcher != null) {
|
||||
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
|
||||
}
|
||||
|
@ -125,6 +162,18 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends Abst
|
|||
http.addFilter(filter);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the final {@link RequestMatcher} to use by combining the {@link #requireCsrfProtectionMatcher(RequestMatcher)} and any {@link #ignore()}.
|
||||
*
|
||||
* @return the {@link RequestMatcher} to use
|
||||
*/
|
||||
private RequestMatcher getRequireCsrfProtectionMatcher() {
|
||||
if(ignoredCsrfProtectionMatchers.isEmpty()) {
|
||||
return requireCsrfProtectionMatcher;
|
||||
}
|
||||
return new AndRequestMatcher(requireCsrfProtectionMatcher, new NegatedRequestMatcher(new OrRequestMatcher(ignoredCsrfProtectionMatchers)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the default {@link AccessDeniedHandler} from the
|
||||
* {@link ExceptionHandlingConfigurer#getAccessDeniedHandler()} or create a
|
||||
|
@ -190,4 +239,25 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends Abst
|
|||
handlers.put(MissingCsrfTokenException.class, invalidSessionDeniedHandler);
|
||||
return new DelegatingAccessDeniedHandler(handlers, defaultAccessDeniedHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Allows registering {@link RequestMatcher} instances that should be
|
||||
* ignored (even if the {@link HttpServletRequest} matches the
|
||||
* {@link CsrfConfigurer#requireCsrfProtectionMatcher(RequestMatcher)}.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @since 4.0
|
||||
*/
|
||||
private class IgnoreCsrfProtectionRegistry extends AbstractRequestMatcherRegistry<IgnoreCsrfProtectionRegistry>{
|
||||
|
||||
public CsrfConfigurer<H> and() {
|
||||
return CsrfConfigurer.this;
|
||||
}
|
||||
|
||||
protected IgnoreCsrfProtectionRegistry chainRequestMatchers(
|
||||
List<RequestMatcher> requestMatchers) {
|
||||
ignoredCsrfProtectionMatchers.addAll(requestMatchers);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -15,6 +15,9 @@
|
|||
*/
|
||||
package org.springframework.security.config.annotation.web.socket;
|
||||
|
||||
import org.springframework.beans.factory.SmartInitializingSingleton;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.core.Ordered;
|
||||
import org.springframework.core.annotation.Order;
|
||||
|
@ -28,11 +31,20 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn
|
|||
import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource;
|
||||
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
|
||||
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
||||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
||||
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
|
||||
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
|
||||
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
|
||||
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
|
||||
import org.springframework.web.socket.sockjs.SockJsService;
|
||||
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
|
||||
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Allows configuring WebSocket Authorization.
|
||||
|
@ -58,9 +70,12 @@ import java.util.List;
|
|||
* @author Rob Winch
|
||||
*/
|
||||
@Order(Ordered.HIGHEST_PRECEDENCE + 100)
|
||||
public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer {
|
||||
public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer
|
||||
implements SmartInitializingSingleton {
|
||||
private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry();
|
||||
|
||||
private ApplicationContext context;
|
||||
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {}
|
||||
|
||||
@Override
|
||||
|
@ -69,16 +84,34 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A
|
|||
argumentResolvers.add(new AuthenticationPrincipalArgumentResolver());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public final void configureClientInboundChannel(ChannelRegistration registration) {
|
||||
ChannelSecurityInterceptor inboundChannelSecurity = inboundChannelSecurity();
|
||||
registration.setInterceptors(securityContextChannelInterceptor());
|
||||
if(sameOriginEnforced()) {
|
||||
registration.setInterceptors(csrfChannelInterceptor());
|
||||
}
|
||||
if(inboundRegistry.containsMapping()) {
|
||||
registration.setInterceptors(securityContextChannelInterceptor(),inboundChannelSecurity);
|
||||
registration.setInterceptors(inboundChannelSecurity);
|
||||
}
|
||||
customizeClientInboundChannel(registration);
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* Determines if a CSRF token is required for connecting. This protects against remote sites from connecting to the
|
||||
* application and being able to read/write data over the connection. The default is true.
|
||||
* </p>
|
||||
* <p>
|
||||
* Subclasses can override this method to disable CSRF protection
|
||||
* </p>
|
||||
*
|
||||
* @return true if a CSRF is required for connecting, else false
|
||||
*/
|
||||
protected boolean sameOriginEnforced() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Allows subclasses to customize the configuration of the {@link ChannelRegistration}.
|
||||
*
|
||||
|
@ -87,6 +120,11 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A
|
|||
protected void customizeClientInboundChannel(ChannelRegistration registration) {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public CsrfChannelInterceptor csrfChannelInterceptor() {
|
||||
return new CsrfChannelInterceptor();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ChannelSecurityInterceptor inboundChannelSecurity() {
|
||||
ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(inboundMessageSecurityMetadataSource());
|
||||
|
@ -125,4 +163,47 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A
|
|||
return super.containsMapping();
|
||||
}
|
||||
}
|
||||
|
||||
@Autowired
|
||||
public void setApplicationContext(ApplicationContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
public void afterSingletonsInstantiated() {
|
||||
if(!sameOriginEnforced()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String beanName = "stompWebSocketHandlerMapping";
|
||||
SimpleUrlHandlerMapping mapping = context.getBean(beanName, SimpleUrlHandlerMapping.class);
|
||||
Map<String, Object> mappings = mapping.getHandlerMap();
|
||||
for(Object object : mappings.values()) {
|
||||
if(object instanceof SockJsHttpRequestHandler) {
|
||||
SockJsHttpRequestHandler sockjsHandler = (SockJsHttpRequestHandler) object;
|
||||
SockJsService sockJsService = sockjsHandler.getSockJsService();
|
||||
if(!(sockJsService instanceof TransportHandlingSockJsService)) {
|
||||
throw new IllegalStateException("sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService);
|
||||
}
|
||||
|
||||
TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService) sockJsService;
|
||||
List<HandshakeInterceptor> handshakeInterceptors = transportHandlingSockJsService.getHandshakeInterceptors();
|
||||
List<HandshakeInterceptor> interceptorsToSet = new ArrayList<HandshakeInterceptor>(handshakeInterceptors.size() + 1);
|
||||
interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
|
||||
interceptorsToSet.addAll(handshakeInterceptors);
|
||||
|
||||
transportHandlingSockJsService.setHandshakeInterceptors(interceptorsToSet);
|
||||
}
|
||||
else if(object instanceof WebSocketHttpRequestHandler) {
|
||||
WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) object;
|
||||
List<HandshakeInterceptor> handshakeInterceptors = handler.getHandshakeInterceptors();
|
||||
List<HandshakeInterceptor> interceptorsToSet = new ArrayList<HandshakeInterceptor>(handshakeInterceptors.size() + 1);
|
||||
interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
|
||||
interceptorsToSet.addAll(handshakeInterceptors);
|
||||
|
||||
handler.setHandshakeInterceptors(interceptorsToSet);
|
||||
} else {
|
||||
throw new IllegalStateException("Bean " + beanName + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -33,6 +33,8 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn
|
|||
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
|
||||
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
||||
import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher;
|
||||
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
||||
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.util.xml.DomUtils;
|
||||
import org.w3c.dom.Element;
|
||||
|
@ -152,7 +154,8 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
|
|||
String[] beanNames = registry.getBeanDefinitionNames();
|
||||
for(String beanName : beanNames) {
|
||||
BeanDefinition bd = registry.getBeanDefinition(beanName);
|
||||
if(bd.getBeanClassName().equals(SimpAnnotationMethodMessageHandler.class.getName())) {
|
||||
String beanClassName = bd.getBeanClassName();
|
||||
if(beanClassName.equals(SimpAnnotationMethodMessageHandler.class.getName())) {
|
||||
PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP);
|
||||
ManagedList<Object> argResolvers = new ManagedList<Object>();
|
||||
if(current != null) {
|
||||
|
@ -161,6 +164,13 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
|
|||
argResolvers.add(new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class));
|
||||
bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers);
|
||||
}
|
||||
else if(beanClassName.equals("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler")) {
|
||||
addCsrfTokenHandshakeInterceptor(bd);
|
||||
} else if(beanClassName.equals("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService")) {
|
||||
addCsrfTokenHandshakeInterceptor(bd);
|
||||
} else if(beanClassName.equals("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService")) {
|
||||
addCsrfTokenHandshakeInterceptor(bd);
|
||||
}
|
||||
}
|
||||
|
||||
if(!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) {
|
||||
|
@ -168,6 +178,7 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
|
|||
}
|
||||
ManagedList<Object> interceptors = new ManagedList();
|
||||
interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class));
|
||||
interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class));
|
||||
interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId));
|
||||
|
||||
BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID);
|
||||
|
@ -180,6 +191,14 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
|
|||
inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors);
|
||||
}
|
||||
|
||||
private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) {
|
||||
String interceptorPropertyName = "handshakeInterceptors";
|
||||
ManagedList<? super Object> interceptors = new ManagedList<Object>();
|
||||
interceptors.add(new RootBeanDefinition(CsrfTokenHandshakeInterceptor.class));
|
||||
interceptors.addAll((ManagedList<Object>)bd.getPropertyValues().get(interceptorPropertyName));
|
||||
bd.getPropertyValues().add(interceptorPropertyName, interceptors);
|
||||
}
|
||||
|
||||
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
|
||||
|
||||
}
|
||||
|
|
|
@ -12,12 +12,24 @@ import org.springframework.http.server.ServerHttpRequest
|
|||
import org.springframework.http.server.ServerHttpResponse
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping
|
||||
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver
|
||||
import org.springframework.messaging.simp.SimpMessageType
|
||||
import org.springframework.mock.web.MockHttpServletRequest
|
||||
import org.springframework.mock.web.MockHttpServletResponse
|
||||
import org.springframework.security.core.Authentication
|
||||
import org.springframework.security.core.annotation.AuthenticationPrincipal
|
||||
import org.springframework.security.web.csrf.CsrfToken
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException
|
||||
import org.springframework.stereotype.Controller
|
||||
import org.springframework.web.servlet.HandlerMapping
|
||||
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping
|
||||
import org.springframework.web.socket.WebSocketHandler
|
||||
import org.springframework.web.socket.server.HandshakeFailureException
|
||||
import org.springframework.web.socket.server.HandshakeHandler
|
||||
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor
|
||||
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler
|
||||
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler
|
||||
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler
|
||||
|
||||
import static org.mockito.Mockito.*
|
||||
|
||||
|
@ -37,6 +49,7 @@ import org.springframework.security.core.context.SecurityContextHolder
|
|||
*/
|
||||
class MessagesConfigTests extends AbstractXmlConfigTests {
|
||||
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
|
||||
boolean useSockJS = false
|
||||
|
||||
def cleanup() {
|
||||
SecurityContextHolder.clearContext()
|
||||
|
@ -93,6 +106,89 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|||
controller.authenticationPrincipal == messageUser.name
|
||||
}
|
||||
|
||||
def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor'() {
|
||||
setup:
|
||||
def id = 'authenticationController'
|
||||
bean(id,MyController)
|
||||
bean('inPostProcessor',InboundExecutorPostProcessor)
|
||||
messages {
|
||||
'message-interceptor'(pattern:'/**',access:'permitAll')
|
||||
}
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
|
||||
Message<?> message = message(headers,'/authentication')
|
||||
WebSocketHttpRequestHandler handler = appContext.getBean(WebSocketHttpRequestHandler)
|
||||
MockHttpServletRequest request = new MockHttpServletRequest()
|
||||
String sessionAttr = "sessionAttr"
|
||||
request.getSession().setAttribute(sessionAttr,"sessionValue")
|
||||
|
||||
CsrfToken token = new DefaultCsrfToken("header", "param", "token")
|
||||
request.setAttribute(CsrfToken.name, token)
|
||||
|
||||
when:
|
||||
handler.handleRequest(request , new MockHttpServletResponse())
|
||||
TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
|
||||
|
||||
then: 'CsrfToken is populated'
|
||||
handshakeHandler.attributes.get(CsrfToken.name) == token
|
||||
|
||||
and: 'Explicitly listed HandshakeInterceptor are not overridden'
|
||||
handshakeHandler.attributes.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
|
||||
}
|
||||
|
||||
def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor with SockJS'() {
|
||||
setup:
|
||||
useSockJS = true
|
||||
def id = 'authenticationController'
|
||||
bean(id,MyController)
|
||||
bean('inPostProcessor',InboundExecutorPostProcessor)
|
||||
messages {
|
||||
'message-interceptor'(pattern:'/**',access:'permitAll')
|
||||
}
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
|
||||
Message<?> message = message(headers,'/authentication')
|
||||
SockJsHttpRequestHandler handler = appContext.getBean(SockJsHttpRequestHandler)
|
||||
MockHttpServletRequest request = new MockHttpServletRequest()
|
||||
String sessionAttr = "sessionAttr"
|
||||
request.getSession().setAttribute(sessionAttr,"sessionValue")
|
||||
|
||||
CsrfToken token = new DefaultCsrfToken("header", "param", "token")
|
||||
request.setAttribute(CsrfToken.name, token)
|
||||
|
||||
request.setMethod("GET")
|
||||
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket")
|
||||
|
||||
when:
|
||||
handler.handleRequest(request , new MockHttpServletResponse())
|
||||
TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
|
||||
|
||||
then: 'CsrfToken is populated'
|
||||
handshakeHandler.attributes?.get(CsrfToken.name) == token
|
||||
|
||||
and: 'Explicitly listed HandshakeInterceptor are not overridden'
|
||||
handshakeHandler.attributes?.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
|
||||
}
|
||||
|
||||
def 'messages of type CONNECT require valid CsrfToken'() {
|
||||
setup:
|
||||
def id = 'authenticationController'
|
||||
bean(id,MyController)
|
||||
bean('inPostProcessor',InboundExecutorPostProcessor)
|
||||
messages {
|
||||
'message-interceptor'(pattern:'/**',access:'permitAll')
|
||||
}
|
||||
|
||||
when: 'message of type CONNECTION is sent without CsrfTOken'
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
|
||||
Message<?> message = message(headers,'/authentication')
|
||||
clientInboundChannel.send(message)
|
||||
|
||||
then: 'CSRF Protection blocks the Message'
|
||||
MessageDeliveryException expected = thrown()
|
||||
expected.cause instanceof MissingCsrfTokenException
|
||||
}
|
||||
|
||||
def 'messages with no id does not override customArgumentResolvers'() {
|
||||
setup:
|
||||
def id = 'authenticationController'
|
||||
|
@ -201,6 +297,12 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|||
'websocket:transport' {}
|
||||
'websocket:stomp-endpoint'(path:'/app') {
|
||||
'websocket:handshake-handler'(ref:'testHandler') {}
|
||||
'websocket:handshake-interceptors' {
|
||||
'b:bean'('class':HttpSessionHandshakeInterceptor.name) {}
|
||||
}
|
||||
if(useSockJS) {
|
||||
'websocket:sockjs' {}
|
||||
}
|
||||
}
|
||||
'websocket:simple-broker'(prefix:"/queue, /topic"){}
|
||||
}
|
||||
|
@ -214,6 +316,11 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|||
|
||||
def message(String destination) {
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
|
||||
message(headers, destination)
|
||||
}
|
||||
|
||||
def message(SimpMessageHeaderAccessor headers, String destination) {
|
||||
messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
|
||||
headers.sessionId = '123'
|
||||
headers.sessionAttributes = [:]
|
||||
headers.destination = destination
|
||||
|
@ -257,8 +364,15 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|||
}
|
||||
|
||||
static class TestHandshakeHandler implements HandshakeHandler {
|
||||
@Override
|
||||
Map<String, Object> attributes;
|
||||
|
||||
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
|
||||
this.attributes = attributes
|
||||
if(wsHandler instanceof SockJsWebSocketHandler) {
|
||||
// work around SPR-12716
|
||||
SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
|
||||
this.attributes = sockJs.sockJsSession.attributes
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -122,6 +122,120 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|||
assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addsAuthenticationPrincipalResolverWhenNoAuthorization() throws InterruptedException {
|
||||
loadConfig(NoInboundSecurityConfig.class);
|
||||
|
||||
MessageChannel messageChannel = clientInboundChannel();
|
||||
Message<String> message = message("/permitAll/authentication");
|
||||
messageChannel.send(message);
|
||||
|
||||
assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addsCsrfProtectionWhenNoAuthorization() throws InterruptedException {
|
||||
loadConfig(NoInboundSecurityConfig.class);
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
Message<?> message = message(headers, "/authentication");
|
||||
MessageChannel messageChannel = clientInboundChannel();
|
||||
|
||||
try {
|
||||
messageChannel.send(message);
|
||||
fail("Expected Exception");
|
||||
} catch(MessageDeliveryException success) {
|
||||
assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void csrfProtectionForConnect() throws InterruptedException {
|
||||
loadConfig(SockJsSecurityConfig.class);
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
Message<?> message = message(headers, "/authentication");
|
||||
MessageChannel messageChannel = clientInboundChannel();
|
||||
|
||||
try {
|
||||
messageChannel.send(message);
|
||||
fail("Expected Exception");
|
||||
} catch(MessageDeliveryException success) {
|
||||
assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void csrfProtectionDisabledForConnect() throws InterruptedException {
|
||||
loadConfig(CsrfDisabledSockJsSecurityConfig.class);
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
Message<?> message = message(headers, "/permitAll/connect");
|
||||
MessageChannel messageChannel = clientInboundChannel();
|
||||
|
||||
messageChannel.send(message);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void messagesConnectUseCsrfTokenHandshakeInterceptor() throws Exception {
|
||||
|
||||
loadConfig(SockJsSecurityConfig.class);
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
Message<?> message = message(headers, "/authentication");
|
||||
MockHttpServletRequest request = sockjsHttpRequest("/chat");
|
||||
HttpRequestHandler handler = handler(request);
|
||||
|
||||
handler.handleRequest(request, new MockHttpServletResponse());
|
||||
|
||||
assertHandshake(request);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void messagesConnectUseCsrfTokenHandshakeInterceptorMultipleMappings() throws Exception {
|
||||
loadConfig(SockJsSecurityConfig.class);
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
Message<?> message = message(headers, "/authentication");
|
||||
MockHttpServletRequest request = sockjsHttpRequest("/other");
|
||||
HttpRequestHandler handler = handler(request);
|
||||
|
||||
handler.handleRequest(request, new MockHttpServletResponse());
|
||||
|
||||
assertHandshake(request);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void messagesConnectWebSocketUseCsrfTokenHandshakeInterceptor() throws Exception {
|
||||
loadConfig(WebSocketSecurityConfig.class);
|
||||
|
||||
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
Message<?> message = message(headers, "/authentication");
|
||||
MockHttpServletRequest request = websocketHttpRequest("/websocket");
|
||||
HttpRequestHandler handler = handler(request);
|
||||
|
||||
handler.handleRequest(request, new MockHttpServletResponse());
|
||||
|
||||
assertHandshake(request);
|
||||
}
|
||||
|
||||
private void assertHandshake(HttpServletRequest request) {
|
||||
TestHandshakeHandler handshakeHandler = context.getBean(TestHandshakeHandler.class);
|
||||
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(token);
|
||||
assertThat(handshakeHandler.attributes.get(sessionAttr)).isEqualTo(request.getSession().getAttribute(sessionAttr));
|
||||
}
|
||||
|
||||
private HttpRequestHandler handler(HttpServletRequest request) throws Exception {
|
||||
HandlerMapping handlerMapping = context.getBean(HandlerMapping.class);
|
||||
return (HttpRequestHandler) handlerMapping.getHandler(request).getHandler();
|
||||
}
|
||||
|
||||
private MockHttpServletRequest websocketHttpRequest(String mapping) {
|
||||
MockHttpServletRequest request = sockjsHttpRequest(mapping);
|
||||
request.setRequestURI(mapping);
|
||||
return request;
|
||||
}
|
||||
|
||||
private MockHttpServletRequest sockjsHttpRequest(String mapping) {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setMethod("GET");
|
||||
|
@ -255,6 +369,75 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Configuration
|
||||
@EnableWebSocketMessageBroker
|
||||
@Import(SyncExecutorConfig.class)
|
||||
static class NoInboundSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
|
||||
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||
registry
|
||||
.addEndpoint("/other")
|
||||
.withSockJS()
|
||||
.setInterceptors(new HttpSessionHandshakeInterceptor());
|
||||
|
||||
registry
|
||||
.addEndpoint("/chat")
|
||||
.withSockJS()
|
||||
.setInterceptors(new HttpSessionHandshakeInterceptor());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void configureMessageBroker(MessageBrokerRegistry registry) {
|
||||
registry.enableSimpleBroker("/queue/", "/topic/");
|
||||
registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MyController myController() {
|
||||
return new MyController();
|
||||
}
|
||||
}
|
||||
|
||||
@Configuration
|
||||
static class CsrfDisabledSockJsSecurityConfig extends SockJsSecurityConfig {
|
||||
|
||||
@Override
|
||||
protected boolean sameOriginEnforced() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Configuration
|
||||
@EnableWebSocketMessageBroker
|
||||
@Import(SyncExecutorConfig.class)
|
||||
static class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
|
||||
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||
registry
|
||||
.addEndpoint("/websocket")
|
||||
.setHandshakeHandler(testHandshakeHandler())
|
||||
.addInterceptors(new HttpSessionHandshakeInterceptor());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
|
||||
messages
|
||||
.simpDestMatchers("/permitAll/**").permitAll()
|
||||
.anyMessage().denyAll();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public TestHandshakeHandler testHandshakeHandler() {
|
||||
return new TestHandshakeHandler();
|
||||
}
|
||||
}
|
||||
|
||||
@Configuration
|
||||
static class SyncExecutorConfig {
|
||||
@Bean
|
||||
|
|
|
@ -7803,6 +7803,8 @@ The messages attribute has two different modes. If the <<nsa-messages-id>> is no
|
|||
* Ensure that any SimpAnnotationMethodMessageHandler has the AuthenticationPrincipalArgumentResolver registered as a custom argument resolver. This allows the use of `@AuthenticationPrincipal` to resolve the principal of the current `Authentication`
|
||||
* Ensures that the SecurityContextChannelInterceptor is automatically registered for the clientInboundChannel. This populates the SecurityContextHolder with the user that is found in the Message
|
||||
* Ensures that a ChannelSecurityInterceptor is registered with the clientInboundChannel. This allows authorization rules to be specified for a message.
|
||||
* Ensures that a CsrfChannelInterceptor is registered with the clientInboundChannel. This ensures that only requests from the original domain are enabled.
|
||||
* Ensures that a CsrfTokenHandshakeInterceptor is registered with WebSocketHttpRequestHandler, TransportHandlingSockJsService, or DefaultSockJsService. This ensures that the expected CsrfToken from the HttpServletRequest is copied into the WebSocket Session attributes.
|
||||
|
||||
If additional control is necessary, the id can be specified and a ChannelSecurityInterceptor will be assigned to the specified id. All the wiring with Spring's messaging infrastructure can then be done manually. This is more cumbersome, but provides greater control over the configuration.
|
||||
|
||||
|
|
|
@ -9,6 +9,10 @@ dependencies {
|
|||
"org.springframework:spring-expression:$springVersion",
|
||||
"org.springframework:spring-messaging:$springVersion"
|
||||
|
||||
optional project(':spring-security-web'),
|
||||
"org.springframework:spring-websocket:$springVersion",
|
||||
"javax.servlet:javax.servlet-api:$servletApiVersion"
|
||||
|
||||
testCompile project(':spring-security-core').sourceSets.test.output,
|
||||
"commons-codec:commons-codec:$commonsCodecVersion",
|
||||
"org.slf4j:jcl-over-slf4j:$slf4jVersion",
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.messaging.web.csrf;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||
import org.springframework.messaging.simp.SimpMessageType;
|
||||
import org.springframework.messaging.support.ChannelInterceptorAdapter;
|
||||
import org.springframework.security.messaging.util.matcher.MessageMatcher;
|
||||
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
|
||||
/**
|
||||
* {@link ChannelInterceptorAdapter} that validates that a valid CSRF is included in the header of any
|
||||
* {@link SimpMessageType#CONNECT} message. The expected {@link CsrfToken} is populated by CsrfTokenHandshakeInterceptor.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @since 4.0
|
||||
*/
|
||||
public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter {
|
||||
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
|
||||
|
||||
@Override
|
||||
public Message<?> preSend(Message<?> message, MessageChannel channel) {
|
||||
if(!matcher.matches(message)) {
|
||||
return message;
|
||||
}
|
||||
|
||||
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
|
||||
CsrfToken expectedToken = sessionAttributes == null ? null : (CsrfToken) sessionAttributes.get(CsrfToken.class.getName());
|
||||
|
||||
if(expectedToken == null) {
|
||||
throw new MissingCsrfTokenException(null);
|
||||
}
|
||||
|
||||
String actualTokenValue = SimpMessageHeaderAccessor.wrap(message).getFirstNativeHeader(expectedToken.getHeaderName());
|
||||
|
||||
boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue);
|
||||
if(csrfCheckPassed) {
|
||||
return message;
|
||||
}
|
||||
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.messaging.web.socket.server;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.springframework.http.server.ServerHttpRequest;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
/**
|
||||
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket attributes. This is used as the
|
||||
* expected CsrfToken when validating connection requests to ensure only the same origin connects.
|
||||
*
|
||||
* @author Rob Winch
|
||||
* @since 4.0
|
||||
*/
|
||||
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
|
||||
|
||||
public boolean beforeHandshake(ServerHttpRequest request,
|
||||
ServerHttpResponse response, WebSocketHandler wsHandler,
|
||||
Map<String, Object> attributes) throws Exception {
|
||||
HttpServletRequest httpRequest = ((ServletServerHttpRequest)request).getServletRequest();
|
||||
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
|
||||
if(token == null) {
|
||||
return true;
|
||||
}
|
||||
attributes.put(CsrfToken.class.getName(), token);
|
||||
return true;
|
||||
}
|
||||
|
||||
public void afterHandshake(ServerHttpRequest request,
|
||||
ServerHttpResponse response, WebSocketHandler wsHandler,
|
||||
Exception exception) {
|
||||
}
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.messaging.web.csrf;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.runners.MockitoJUnitRunner;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||
import org.springframework.messaging.simp.SimpMessageType;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
|
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class CsrfChannelInterceptorTests {
|
||||
@Mock
|
||||
MessageChannel channel;
|
||||
|
||||
SimpMessageHeaderAccessor messageHeaders;
|
||||
|
||||
CsrfToken token;
|
||||
|
||||
CsrfChannelInterceptor interceptor;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
token = new DefaultCsrfToken("header", "param", "token");
|
||||
interceptor = new CsrfChannelInterceptor();
|
||||
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
|
||||
messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken());
|
||||
messageHeaders.setSessionAttributes(new HashMap<String,Object>());
|
||||
messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), token);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendValidToken() {
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresConnectAck() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresDisconnect() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresDisconnectAck() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresHeartbeat() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresMessage() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresOther() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresSubscribe() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void preSendIgnoresUnsubscribe() {
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test(expected = InvalidCsrfTokenException.class)
|
||||
public void preSendNoToken() {
|
||||
messageHeaders.removeNativeHeader(token.getHeaderName());
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test(expected = InvalidCsrfTokenException.class)
|
||||
public void preSendInvalidToken() {
|
||||
messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken() + "invalid");
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test(expected = MissingCsrfTokenException.class)
|
||||
public void preSendMissingToken() {
|
||||
messageHeaders.getSessionAttributes().clear();
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
@Test(expected = MissingCsrfTokenException.class)
|
||||
public void preSendMissingTokenNullSessionAttributes() {
|
||||
messageHeaders.setSessionAttributes(null);
|
||||
|
||||
interceptor.preSend(message(), channel);
|
||||
}
|
||||
|
||||
private Message<String> message() {
|
||||
Map<String, Object> headersToCopy = messageHeaders.toMap();
|
||||
return MessageBuilder
|
||||
.withPayload("hi")
|
||||
.copyHeaders(headersToCopy)
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
* use this file except in compliance with the License. You may obtain a copy of
|
||||
* the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations under
|
||||
* the License.
|
||||
*/
|
||||
package org.springframework.security.messaging.web.socket.server;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.Before;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.runners.MockitoJUnitRunner;
|
||||
import org.springframework.http.server.ServerHttpRequest;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.security.web.csrf.CsrfToken;
|
||||
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.fest.assertions.Assertions.assertThat;
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @author Rob Winch
|
||||
*/
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class CsrfTokenHandshakeInterceptorTests {
|
||||
@Mock
|
||||
WebSocketHandler wsHandler;
|
||||
@Mock
|
||||
ServerHttpResponse response;
|
||||
|
||||
Map<String, Object> attributes;
|
||||
|
||||
ServerHttpRequest request;
|
||||
|
||||
MockHttpServletRequest httpRequest;
|
||||
|
||||
CsrfTokenHandshakeInterceptor interceptor;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
httpRequest = new MockHttpServletRequest();
|
||||
attributes = new HashMap<String,Object>();
|
||||
request = new ServletServerHttpRequest(httpRequest);
|
||||
|
||||
interceptor = new CsrfTokenHandshakeInterceptor();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void beforeHandshakeNoAttribute() throws Exception {
|
||||
interceptor.beforeHandshake(request, response, wsHandler, attributes);
|
||||
|
||||
assertThat(attributes).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void beforeHandshake() throws Exception {
|
||||
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
|
||||
httpRequest.setAttribute(CsrfToken.class.getName(), token);
|
||||
|
||||
interceptor.beforeHandshake(request, response, wsHandler, attributes);
|
||||
|
||||
assertThat(attributes.keySet()).containsOnly(CsrfToken.class.getName());
|
||||
assertThat(attributes.values()).containsOnly(token);
|
||||
}
|
||||
|
||||
}
|
|
@ -55,6 +55,13 @@ import org.springframework.web.filter.OncePerRequestFilter;
|
|||
* @since 3.2
|
||||
*/
|
||||
public final class CsrfFilter extends OncePerRequestFilter {
|
||||
/**
|
||||
* The default {@link RequestMatcher} that indicates if CSRF protection is
|
||||
* required or not. The default is to ignore GET, HEAD, TRACE, OPTIONS and
|
||||
* process all other requests.
|
||||
*/
|
||||
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
|
||||
|
||||
private final Log logger = LogFactory.getLog(getClass());
|
||||
private final CsrfTokenRepository tokenRepository;
|
||||
private RequestMatcher requireCsrfProtectionMatcher = new DefaultRequiresCsrfMatcher();
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.web.csrf;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author Rob Winch
|
||||
*
|
||||
*/
|
||||
public class MissingCsrfTokenExceptionTests {
|
||||
|
||||
// CsrfChannelInterceptor requires this to work
|
||||
@Test
|
||||
public void nullExpectedTokenDoesNotFail() {
|
||||
new MissingCsrfTokenException(null);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue