diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index 47f906fd2a..3a27c23a0e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -15,9 +15,14 @@ */ package org.springframework.security.config.annotation.web.configurers; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; + +import javax.servlet.http.HttpServletRequest; import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.web.access.AccessDeniedHandler; @@ -31,6 +36,9 @@ import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler; import org.springframework.security.web.session.InvalidSessionStrategy; +import org.springframework.security.web.util.matcher.AndRequestMatcher; +import org.springframework.security.web.util.matcher.NegatedRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -66,7 +74,8 @@ import org.springframework.util.Assert; */ public final class CsrfConfigurer> extends AbstractHttpConfigurer,H> { private CsrfTokenRepository csrfTokenRepository = new HttpSessionCsrfTokenRepository(); - private RequestMatcher requireCsrfProtectionMatcher; + private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER; + private List ignoredCsrfProtectionMatchers = new ArrayList(); /** * Creates a new instance @@ -102,10 +111,38 @@ public final class CsrfConfigurer> extends Abst return this; } + /** + *

+ * Allows specifying {@link HttpServletRequest} that should not use CSRF Protection even if they match the {@link #requireCsrfProtectionMatcher(RequestMatcher)}. + *

+ * + *

+ * The following will ensure CSRF protection ignores: + *

+ *
    + *
  • Any GET, HEAD, TRACE, OPTIONS (this is the default)
  • + *
  • We also explicitly state to ignore any request that starts with "/sockjs/"
  • + *
+ * + *
+     * http
+     *     .csrf()
+     *         .ignoringAntMatchers("/sockjs/**")
+     *         .and()
+     *     ...
+     * 
+ * + * @since 4.0 + */ + public CsrfConfigurer ignoringAntMatchers(String... antPatterns) { + return new IgnoreCsrfProtectionRegistry().antMatchers(antPatterns).and(); + } + @SuppressWarnings("unchecked") @Override public void configure(H http) throws Exception { CsrfFilter filter = new CsrfFilter(csrfTokenRepository); + RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher(); if(requireCsrfProtectionMatcher != null) { filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); } @@ -125,6 +162,18 @@ public final class CsrfConfigurer> extends Abst http.addFilter(filter); } + /** + * Gets the final {@link RequestMatcher} to use by combining the {@link #requireCsrfProtectionMatcher(RequestMatcher)} and any {@link #ignore()}. + * + * @return the {@link RequestMatcher} to use + */ + private RequestMatcher getRequireCsrfProtectionMatcher() { + if(ignoredCsrfProtectionMatchers.isEmpty()) { + return requireCsrfProtectionMatcher; + } + return new AndRequestMatcher(requireCsrfProtectionMatcher, new NegatedRequestMatcher(new OrRequestMatcher(ignoredCsrfProtectionMatchers))); + } + /** * Gets the default {@link AccessDeniedHandler} from the * {@link ExceptionHandlingConfigurer#getAccessDeniedHandler()} or create a @@ -190,4 +239,25 @@ public final class CsrfConfigurer> extends Abst handlers.put(MissingCsrfTokenException.class, invalidSessionDeniedHandler); return new DelegatingAccessDeniedHandler(handlers, defaultAccessDeniedHandler); } + + /** + * Allows registering {@link RequestMatcher} instances that should be + * ignored (even if the {@link HttpServletRequest} matches the + * {@link CsrfConfigurer#requireCsrfProtectionMatcher(RequestMatcher)}. + * + * @author Rob Winch + * @since 4.0 + */ + private class IgnoreCsrfProtectionRegistry extends AbstractRequestMatcherRegistry{ + + public CsrfConfigurer and() { + return CsrfConfigurer.this; + } + + protected IgnoreCsrfProtectionRegistry chainRequestMatchers( + List requestMatchers) { + ignoredCsrfProtectionMatchers.addAll(requestMatchers); + return this; + } + } } \ No newline at end of file diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java index d50beffa1c..3335a53aca 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java @@ -15,6 +15,9 @@ */ package org.springframework.security.config.annotation.web.socket; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; @@ -28,11 +31,20 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; +import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; +import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; +import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; +import org.springframework.web.socket.sockjs.SockJsService; +import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; +import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; import java.util.ArrayList; import java.util.List; +import java.util.Map; /** * Allows configuring WebSocket Authorization. @@ -58,9 +70,12 @@ import java.util.List; * @author Rob Winch */ @Order(Ordered.HIGHEST_PRECEDENCE + 100) -public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer { +public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer + implements SmartInitializingSingleton { private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); + private ApplicationContext context; + public void registerStompEndpoints(StompEndpointRegistry registry) {} @Override @@ -69,16 +84,34 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A argumentResolvers.add(new AuthenticationPrincipalArgumentResolver()); } - @Override public final void configureClientInboundChannel(ChannelRegistration registration) { ChannelSecurityInterceptor inboundChannelSecurity = inboundChannelSecurity(); + registration.setInterceptors(securityContextChannelInterceptor()); + if(sameOriginEnforced()) { + registration.setInterceptors(csrfChannelInterceptor()); + } if(inboundRegistry.containsMapping()) { - registration.setInterceptors(securityContextChannelInterceptor(),inboundChannelSecurity); + registration.setInterceptors(inboundChannelSecurity); } customizeClientInboundChannel(registration); } + /** + *

+ * Determines if a CSRF token is required for connecting. This protects against remote sites from connecting to the + * application and being able to read/write data over the connection. The default is true. + *

+ *

+ * Subclasses can override this method to disable CSRF protection + *

+ * + * @return true if a CSRF is required for connecting, else false + */ + protected boolean sameOriginEnforced() { + return true; + } + /** * Allows subclasses to customize the configuration of the {@link ChannelRegistration}. * @@ -87,6 +120,11 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A protected void customizeClientInboundChannel(ChannelRegistration registration) { } + @Bean + public CsrfChannelInterceptor csrfChannelInterceptor() { + return new CsrfChannelInterceptor(); + } + @Bean public ChannelSecurityInterceptor inboundChannelSecurity() { ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(inboundMessageSecurityMetadataSource()); @@ -125,4 +163,47 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A return super.containsMapping(); } } + + @Autowired + public void setApplicationContext(ApplicationContext context) { + this.context = context; + } + + public void afterSingletonsInstantiated() { + if(!sameOriginEnforced()) { + return; + } + + String beanName = "stompWebSocketHandlerMapping"; + SimpleUrlHandlerMapping mapping = context.getBean(beanName, SimpleUrlHandlerMapping.class); + Map mappings = mapping.getHandlerMap(); + for(Object object : mappings.values()) { + if(object instanceof SockJsHttpRequestHandler) { + SockJsHttpRequestHandler sockjsHandler = (SockJsHttpRequestHandler) object; + SockJsService sockJsService = sockjsHandler.getSockJsService(); + if(!(sockJsService instanceof TransportHandlingSockJsService)) { + throw new IllegalStateException("sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService); + } + + TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService) sockJsService; + List handshakeInterceptors = transportHandlingSockJsService.getHandshakeInterceptors(); + List interceptorsToSet = new ArrayList(handshakeInterceptors.size() + 1); + interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); + interceptorsToSet.addAll(handshakeInterceptors); + + transportHandlingSockJsService.setHandshakeInterceptors(interceptorsToSet); + } + else if(object instanceof WebSocketHttpRequestHandler) { + WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) object; + List handshakeInterceptors = handler.getHandshakeInterceptors(); + List interceptorsToSet = new ArrayList(handshakeInterceptors.size() + 1); + interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); + interceptorsToSet.addAll(handshakeInterceptors); + + handler.setHandshakeInterceptors(interceptorsToSet); + } else { + throw new IllegalStateException("Bean " + beanName + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object); + } + } + } } \ No newline at end of file diff --git a/config/src/main/java/org/springframework/security/config/message/MessageSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/message/MessageSecurityBeanDefinitionParser.java index cd946b815b..b4a4625d52 100644 --- a/config/src/main/java/org/springframework/security/config/message/MessageSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/message/MessageSecurityBeanDefinitionParser.java @@ -33,6 +33,8 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; +import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; +import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.w3c.dom.Element; @@ -152,7 +154,8 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition String[] beanNames = registry.getBeanDefinitionNames(); for(String beanName : beanNames) { BeanDefinition bd = registry.getBeanDefinition(beanName); - if(bd.getBeanClassName().equals(SimpAnnotationMethodMessageHandler.class.getName())) { + String beanClassName = bd.getBeanClassName(); + if(beanClassName.equals(SimpAnnotationMethodMessageHandler.class.getName())) { PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP); ManagedList argResolvers = new ManagedList(); if(current != null) { @@ -161,6 +164,13 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition argResolvers.add(new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class)); bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers); } + else if(beanClassName.equals("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler")) { + addCsrfTokenHandshakeInterceptor(bd); + } else if(beanClassName.equals("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService")) { + addCsrfTokenHandshakeInterceptor(bd); + } else if(beanClassName.equals("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService")) { + addCsrfTokenHandshakeInterceptor(bd); + } } if(!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) { @@ -168,6 +178,7 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition } ManagedList interceptors = new ManagedList(); interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); + interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId)); BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); @@ -180,6 +191,14 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors); } + private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) { + String interceptorPropertyName = "handshakeInterceptors"; + ManagedList interceptors = new ManagedList(); + interceptors.add(new RootBeanDefinition(CsrfTokenHandshakeInterceptor.class)); + interceptors.addAll((ManagedList)bd.getPropertyValues().get(interceptorPropertyName)); + bd.getPropertyValues().add(interceptorPropertyName, interceptors); + } + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { } diff --git a/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy b/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy index a0c2f5d4e7..991c01457b 100644 --- a/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy @@ -12,12 +12,24 @@ import org.springframework.http.server.ServerHttpRequest import org.springframework.http.server.ServerHttpResponse import org.springframework.messaging.handler.annotation.MessageMapping import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver +import org.springframework.messaging.simp.SimpMessageType +import org.springframework.mock.web.MockHttpServletRequest +import org.springframework.mock.web.MockHttpServletResponse import org.springframework.security.core.Authentication import org.springframework.security.core.annotation.AuthenticationPrincipal +import org.springframework.security.web.csrf.CsrfToken +import org.springframework.security.web.csrf.DefaultCsrfToken +import org.springframework.security.web.csrf.MissingCsrfTokenException import org.springframework.stereotype.Controller +import org.springframework.web.servlet.HandlerMapping +import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping import org.springframework.web.socket.WebSocketHandler import org.springframework.web.socket.server.HandshakeFailureException import org.springframework.web.socket.server.HandshakeHandler +import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler +import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler +import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler import static org.mockito.Mockito.* @@ -37,6 +49,7 @@ import org.springframework.security.core.context.SecurityContextHolder */ class MessagesConfigTests extends AbstractXmlConfigTests { Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') + boolean useSockJS = false def cleanup() { SecurityContextHolder.clearContext() @@ -93,6 +106,89 @@ class MessagesConfigTests extends AbstractXmlConfigTests { controller.authenticationPrincipal == messageUser.name } + def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor'() { + setup: + def id = 'authenticationController' + bean(id,MyController) + bean('inPostProcessor',InboundExecutorPostProcessor) + messages { + 'message-interceptor'(pattern:'/**',access:'permitAll') + } + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) + Message message = message(headers,'/authentication') + WebSocketHttpRequestHandler handler = appContext.getBean(WebSocketHttpRequestHandler) + MockHttpServletRequest request = new MockHttpServletRequest() + String sessionAttr = "sessionAttr" + request.getSession().setAttribute(sessionAttr,"sessionValue") + + CsrfToken token = new DefaultCsrfToken("header", "param", "token") + request.setAttribute(CsrfToken.name, token) + + when: + handler.handleRequest(request , new MockHttpServletResponse()) + TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler) + + then: 'CsrfToken is populated' + handshakeHandler.attributes.get(CsrfToken.name) == token + + and: 'Explicitly listed HandshakeInterceptor are not overridden' + handshakeHandler.attributes.get(sessionAttr) == request.getSession().getAttribute(sessionAttr) + } + + def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor with SockJS'() { + setup: + useSockJS = true + def id = 'authenticationController' + bean(id,MyController) + bean('inPostProcessor',InboundExecutorPostProcessor) + messages { + 'message-interceptor'(pattern:'/**',access:'permitAll') + } + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) + Message message = message(headers,'/authentication') + SockJsHttpRequestHandler handler = appContext.getBean(SockJsHttpRequestHandler) + MockHttpServletRequest request = new MockHttpServletRequest() + String sessionAttr = "sessionAttr" + request.getSession().setAttribute(sessionAttr,"sessionValue") + + CsrfToken token = new DefaultCsrfToken("header", "param", "token") + request.setAttribute(CsrfToken.name, token) + + request.setMethod("GET") + request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket") + + when: + handler.handleRequest(request , new MockHttpServletResponse()) + TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler) + + then: 'CsrfToken is populated' + handshakeHandler.attributes?.get(CsrfToken.name) == token + + and: 'Explicitly listed HandshakeInterceptor are not overridden' + handshakeHandler.attributes?.get(sessionAttr) == request.getSession().getAttribute(sessionAttr) + } + + def 'messages of type CONNECT require valid CsrfToken'() { + setup: + def id = 'authenticationController' + bean(id,MyController) + bean('inPostProcessor',InboundExecutorPostProcessor) + messages { + 'message-interceptor'(pattern:'/**',access:'permitAll') + } + + when: 'message of type CONNECTION is sent without CsrfTOken' + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) + Message message = message(headers,'/authentication') + clientInboundChannel.send(message) + + then: 'CSRF Protection blocks the Message' + MessageDeliveryException expected = thrown() + expected.cause instanceof MissingCsrfTokenException + } + def 'messages with no id does not override customArgumentResolvers'() { setup: def id = 'authenticationController' @@ -201,6 +297,12 @@ class MessagesConfigTests extends AbstractXmlConfigTests { 'websocket:transport' {} 'websocket:stomp-endpoint'(path:'/app') { 'websocket:handshake-handler'(ref:'testHandler') {} + 'websocket:handshake-interceptors' { + 'b:bean'('class':HttpSessionHandshakeInterceptor.name) {} + } + if(useSockJS) { + 'websocket:sockjs' {} + } } 'websocket:simple-broker'(prefix:"/queue, /topic"){} } @@ -214,6 +316,11 @@ class MessagesConfigTests extends AbstractXmlConfigTests { def message(String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create() + message(headers, destination) + } + + def message(SimpMessageHeaderAccessor headers, String destination) { + messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') headers.sessionId = '123' headers.sessionAttributes = [:] headers.destination = destination @@ -257,8 +364,15 @@ class MessagesConfigTests extends AbstractXmlConfigTests { } static class TestHandshakeHandler implements HandshakeHandler { - @Override + Map attributes; + boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + this.attributes = attributes + if(wsHandler instanceof SockJsWebSocketHandler) { + // work around SPR-12716 + SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler; + this.attributes = sockJs.sockJsSession.attributes + } true } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index ed3b6866e3..f2c16d703a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -122,6 +122,120 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal()); } + @Test + public void addsAuthenticationPrincipalResolverWhenNoAuthorization() throws InterruptedException { + loadConfig(NoInboundSecurityConfig.class); + + MessageChannel messageChannel = clientInboundChannel(); + Message message = message("/permitAll/authentication"); + messageChannel.send(message); + + assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal()); + } + + @Test + public void addsCsrfProtectionWhenNoAuthorization() throws InterruptedException { + loadConfig(NoInboundSecurityConfig.class); + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + Message message = message(headers, "/authentication"); + MessageChannel messageChannel = clientInboundChannel(); + + try { + messageChannel.send(message); + fail("Expected Exception"); + } catch(MessageDeliveryException success) { + assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class); + } + } + + @Test + public void csrfProtectionForConnect() throws InterruptedException { + loadConfig(SockJsSecurityConfig.class); + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + Message message = message(headers, "/authentication"); + MessageChannel messageChannel = clientInboundChannel(); + + try { + messageChannel.send(message); + fail("Expected Exception"); + } catch(MessageDeliveryException success) { + assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class); + } + } + + @Test + public void csrfProtectionDisabledForConnect() throws InterruptedException { + loadConfig(CsrfDisabledSockJsSecurityConfig.class); + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + Message message = message(headers, "/permitAll/connect"); + MessageChannel messageChannel = clientInboundChannel(); + + messageChannel.send(message); + } + + @Test + public void messagesConnectUseCsrfTokenHandshakeInterceptor() throws Exception { + + loadConfig(SockJsSecurityConfig.class); + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + Message message = message(headers, "/authentication"); + MockHttpServletRequest request = sockjsHttpRequest("/chat"); + HttpRequestHandler handler = handler(request); + + handler.handleRequest(request, new MockHttpServletResponse()); + + assertHandshake(request); + } + + @Test + public void messagesConnectUseCsrfTokenHandshakeInterceptorMultipleMappings() throws Exception { + loadConfig(SockJsSecurityConfig.class); + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + Message message = message(headers, "/authentication"); + MockHttpServletRequest request = sockjsHttpRequest("/other"); + HttpRequestHandler handler = handler(request); + + handler.handleRequest(request, new MockHttpServletResponse()); + + assertHandshake(request); + } + + @Test + public void messagesConnectWebSocketUseCsrfTokenHandshakeInterceptor() throws Exception { + loadConfig(WebSocketSecurityConfig.class); + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + Message message = message(headers, "/authentication"); + MockHttpServletRequest request = websocketHttpRequest("/websocket"); + HttpRequestHandler handler = handler(request); + + handler.handleRequest(request, new MockHttpServletResponse()); + + assertHandshake(request); + } + + private void assertHandshake(HttpServletRequest request) { + TestHandshakeHandler handshakeHandler = context.getBean(TestHandshakeHandler.class); + assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(token); + assertThat(handshakeHandler.attributes.get(sessionAttr)).isEqualTo(request.getSession().getAttribute(sessionAttr)); + } + + private HttpRequestHandler handler(HttpServletRequest request) throws Exception { + HandlerMapping handlerMapping = context.getBean(HandlerMapping.class); + return (HttpRequestHandler) handlerMapping.getHandler(request).getHandler(); + } + + private MockHttpServletRequest websocketHttpRequest(String mapping) { + MockHttpServletRequest request = sockjsHttpRequest(mapping); + request.setRequestURI(mapping); + return request; + } + private MockHttpServletRequest sockjsHttpRequest(String mapping) { MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); @@ -255,6 +369,75 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { } } + + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class NoInboundSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .withSockJS() + .setInterceptors(new HttpSessionHandshakeInterceptor()); + + registry + .addEndpoint("/chat") + .withSockJS() + .setInterceptors(new HttpSessionHandshakeInterceptor()); + } + + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + } + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll"); + } + + @Bean + public MyController myController() { + return new MyController(); + } + } + + @Configuration + static class CsrfDisabledSockJsSecurityConfig extends SockJsSecurityConfig { + + @Override + protected boolean sameOriginEnforced() { + return false; + } + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/websocket") + .setHandshakeHandler(testHandshakeHandler()) + .addInterceptors(new HttpSessionHandshakeInterceptor()); + } + + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestMatchers("/permitAll/**").permitAll() + .anyMessage().denyAll(); + } + + @Bean + public TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + } + @Configuration static class SyncExecutorConfig { @Bean diff --git a/docs/manual/src/docs/asciidoc/index.adoc b/docs/manual/src/docs/asciidoc/index.adoc index 6e0f003430..7e5b8306cb 100644 --- a/docs/manual/src/docs/asciidoc/index.adoc +++ b/docs/manual/src/docs/asciidoc/index.adoc @@ -7803,6 +7803,8 @@ The messages attribute has two different modes. If the <> is no * Ensure that any SimpAnnotationMethodMessageHandler has the AuthenticationPrincipalArgumentResolver registered as a custom argument resolver. This allows the use of `@AuthenticationPrincipal` to resolve the principal of the current `Authentication` * Ensures that the SecurityContextChannelInterceptor is automatically registered for the clientInboundChannel. This populates the SecurityContextHolder with the user that is found in the Message * Ensures that a ChannelSecurityInterceptor is registered with the clientInboundChannel. This allows authorization rules to be specified for a message. +* Ensures that a CsrfChannelInterceptor is registered with the clientInboundChannel. This ensures that only requests from the original domain are enabled. +* Ensures that a CsrfTokenHandshakeInterceptor is registered with WebSocketHttpRequestHandler, TransportHandlingSockJsService, or DefaultSockJsService. This ensures that the expected CsrfToken from the HttpServletRequest is copied into the WebSocket Session attributes. If additional control is necessary, the id can be specified and a ChannelSecurityInterceptor will be assigned to the specified id. All the wiring with Spring's messaging infrastructure can then be done manually. This is more cumbersome, but provides greater control over the configuration. diff --git a/messaging/messaging.gradle b/messaging/messaging.gradle index 7c827835ae..cc07363e13 100644 --- a/messaging/messaging.gradle +++ b/messaging/messaging.gradle @@ -9,6 +9,10 @@ dependencies { "org.springframework:spring-expression:$springVersion", "org.springframework:spring-messaging:$springVersion" + optional project(':spring-security-web'), + "org.springframework:spring-websocket:$springVersion", + "javax.servlet:javax.servlet-api:$servletApiVersion" + testCompile project(':spring-security-core').sourceSets.test.output, "commons-codec:commons-codec:$commonsCodecVersion", "org.slf4j:jcl-over-slf4j:$slf4jVersion", diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java new file mode 100644 index 0000000000..66c798b755 --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.messaging.web.csrf; + +import java.util.Map; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.ChannelInterceptorAdapter; +import org.springframework.security.messaging.util.matcher.MessageMatcher; +import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +/** + * {@link ChannelInterceptorAdapter} that validates that a valid CSRF is included in the header of any + * {@link SimpMessageType#CONNECT} message. The expected {@link CsrfToken} is populated by CsrfTokenHandshakeInterceptor. + * + * @author Rob Winch + * @since 4.0 + */ +public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter { + private final MessageMatcher matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT); + + @Override + public Message preSend(Message message, MessageChannel channel) { + if(!matcher.matches(message)) { + return message; + } + + Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); + CsrfToken expectedToken = sessionAttributes == null ? null : (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()); + + if(expectedToken == null) { + throw new MissingCsrfTokenException(null); + } + + String actualTokenValue = SimpMessageHeaderAccessor.wrap(message).getFirstNativeHeader(expectedToken.getHeaderName()); + + boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue); + if(csrfCheckPassed) { + return message; + } + throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); + } +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java new file mode 100644 index 0000000000..dad7c181ed --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.messaging.web.socket.server; + +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +/** + * Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket attributes. This is used as the + * expected CsrfToken when validating connection requests to ensure only the same origin connects. + * + * @author Rob Winch + * @since 4.0 + */ +public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor { + + public boolean beforeHandshake(ServerHttpRequest request, + ServerHttpResponse response, WebSocketHandler wsHandler, + Map attributes) throws Exception { + HttpServletRequest httpRequest = ((ServletServerHttpRequest)request).getServletRequest(); + CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName()); + if(token == null) { + return true; + } + attributes.put(CsrfToken.class.getName(), token); + return true; + } + + public void afterHandshake(ServerHttpRequest request, + ServerHttpResponse response, WebSocketHandler wsHandler, + Exception exception) { + } +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java new file mode 100644 index 0000000000..69b9785a5e --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.messaging.web.csrf; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +@RunWith(MockitoJUnitRunner.class) +public class CsrfChannelInterceptorTests { + @Mock + MessageChannel channel; + + SimpMessageHeaderAccessor messageHeaders; + + CsrfToken token; + + CsrfChannelInterceptor interceptor; + + @Before + public void setup() { + token = new DefaultCsrfToken("header", "param", "token"); + interceptor = new CsrfChannelInterceptor(); + + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken()); + messageHeaders.setSessionAttributes(new HashMap()); + messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), token); + } + + @Test + public void preSendValidToken() { + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresConnectAck() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresDisconnect() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresDisconnectAck() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresHeartbeat() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresMessage() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresOther() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresSubscribe() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); + + interceptor.preSend(message(), channel); + } + + @Test + public void preSendIgnoresUnsubscribe() { + messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + + interceptor.preSend(message(), channel); + } + + @Test(expected = InvalidCsrfTokenException.class) + public void preSendNoToken() { + messageHeaders.removeNativeHeader(token.getHeaderName()); + + interceptor.preSend(message(), channel); + } + + @Test(expected = InvalidCsrfTokenException.class) + public void preSendInvalidToken() { + messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken() + "invalid"); + + interceptor.preSend(message(), channel); + } + + @Test(expected = MissingCsrfTokenException.class) + public void preSendMissingToken() { + messageHeaders.getSessionAttributes().clear(); + + interceptor.preSend(message(), channel); + } + + @Test(expected = MissingCsrfTokenException.class) + public void preSendMissingTokenNullSessionAttributes() { + messageHeaders.setSessionAttributes(null); + + interceptor.preSend(message(), channel); + } + + private Message message() { + Map headersToCopy = messageHeaders.toMap(); + return MessageBuilder + .withPayload("hi") + .copyHeaders(headersToCopy) + .build(); + } +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java new file mode 100644 index 0000000000..d7f29c4dea --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.security.messaging.web.socket.server; + +import org.junit.Test; +import org.junit.Before; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.web.socket.WebSocketHandler; + +import java.util.HashMap; +import java.util.Map; + +import static org.fest.assertions.Assertions.assertThat; + + +/** + * + * @author Rob Winch + */ +@RunWith(MockitoJUnitRunner.class) +public class CsrfTokenHandshakeInterceptorTests { + @Mock + WebSocketHandler wsHandler; + @Mock + ServerHttpResponse response; + + Map attributes; + + ServerHttpRequest request; + + MockHttpServletRequest httpRequest; + + CsrfTokenHandshakeInterceptor interceptor; + + @Before + public void setup() { + httpRequest = new MockHttpServletRequest(); + attributes = new HashMap(); + request = new ServletServerHttpRequest(httpRequest); + + interceptor = new CsrfTokenHandshakeInterceptor(); + } + + @Test + public void beforeHandshakeNoAttribute() throws Exception { + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertThat(attributes).isEmpty(); + } + + @Test + public void beforeHandshake() throws Exception { + CsrfToken token = new DefaultCsrfToken("header", "param", "token"); + httpRequest.setAttribute(CsrfToken.class.getName(), token); + + interceptor.beforeHandshake(request, response, wsHandler, attributes); + + assertThat(attributes.keySet()).containsOnly(CsrfToken.class.getName()); + assertThat(attributes.values()).containsOnly(token); + } + +} \ No newline at end of file diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 683d20d4ab..83ea69308a 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -55,6 +55,13 @@ import org.springframework.web.filter.OncePerRequestFilter; * @since 3.2 */ public final class CsrfFilter extends OncePerRequestFilter { + /** + * The default {@link RequestMatcher} that indicates if CSRF protection is + * required or not. The default is to ignore GET, HEAD, TRACE, OPTIONS and + * process all other requests. + */ + public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher(); + private final Log logger = LogFactory.getLog(getClass()); private final CsrfTokenRepository tokenRepository; private RequestMatcher requireCsrfProtectionMatcher = new DefaultRequiresCsrfMatcher(); diff --git a/web/src/test/java/org/springframework/security/web/csrf/MissingCsrfTokenExceptionTests.java b/web/src/test/java/org/springframework/security/web/csrf/MissingCsrfTokenExceptionTests.java new file mode 100644 index 0000000000..d7f276efd5 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/MissingCsrfTokenExceptionTests.java @@ -0,0 +1,33 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.csrf; + +import org.junit.Test; + +/** + * + * @author Rob Winch + * + */ +public class MissingCsrfTokenExceptionTests { + + // CsrfChannelInterceptor requires this to work + @Test + public void nullExpectedTokenDoesNotFail() { + new MissingCsrfTokenException(null); + } + +}