SEC-2864: Default Spring Security WebSocket PathMatcher XML Namespace

This commit is contained in:
Rob Winch 2015-03-25 16:32:03 -05:00
parent db531d9100
commit 7b25b3e40d
2 changed files with 542 additions and 455 deletions

View File

@ -38,11 +38,14 @@ import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatche
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.util.AntPathMatcher; import org.springframework.util.AntPathMatcher;
import org.springframework.util.PathMatcher;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* Parses Spring Security's websocket namespace support. A simple example is: * Parses Spring Security's websocket namespace support. A simple example is:
@ -84,9 +87,6 @@ import java.util.List;
*/ */
public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
BeanDefinitionParser { BeanDefinitionParser {
private static final Log logger = LogFactory
.getLog(WebSocketMessageBrokerSecurityBeanDefinitionParser.class);
private static final String ID_ATTR = "id"; private static final String ID_ATTR = "id";
private static final String DISABLED_ATTR = "same-origin-disabled"; private static final String DISABLED_ATTR = "same-origin-disabled";
@ -97,6 +97,8 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
private static final String TYPE_ATTR = "type"; private static final String TYPE_ATTR = "type";
private static final String PATH_MATCHER_BEAN_NAME = "springSecurityMessagePathMatcher";
/** /**
* @param element * @param element
* @param parserContext * @param parserContext
@ -149,6 +151,10 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
if (StringUtils.hasText(id)) { if (StringUtils.hasText(id)) {
registry.registerAlias(inSecurityInterceptorName, id); registry.registerAlias(inSecurityInterceptorName, id);
if(!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) {
registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, new RootBeanDefinition(AntPathMatcher.class));
}
} }
else { else {
BeanDefinitionBuilder mspp = BeanDefinitionBuilder BeanDefinitionBuilder mspp = BeanDefinitionBuilder
@ -190,16 +196,18 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
interceptMessage); interceptMessage);
} }
} }
BeanDefinitionBuilder matcher = BeanDefinitionBuilder BeanDefinitionBuilder matcher = BeanDefinitionBuilder
.rootBeanDefinition(SimpDestinationMessageMatcher.class); .rootBeanDefinition(SimpDestinationMessageMatcher.class);
matcher.setFactoryMethod(factoryName); matcher.setFactoryMethod(factoryName);
matcher.addConstructorArgValue(matcherPattern); matcher.addConstructorArgValue(matcherPattern);
matcher.addConstructorArgValue(new RootBeanDefinition(AntPathMatcher.class)); matcher.addConstructorArgValue(new RuntimeBeanReference("springSecurityMessagePathMatcher"));
return matcher.getBeanDefinition(); return matcher.getBeanDefinition();
} }
static class MessageSecurityPostProcessor implements static class MessageSecurityPostProcessor implements
BeanDefinitionRegistryPostProcessor { BeanDefinitionRegistryPostProcessor {
private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel"; private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel";
private static final String INTERCEPTORS_PROP = "interceptors"; private static final String INTERCEPTORS_PROP = "interceptors";
@ -233,6 +241,14 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
argResolvers.add(new RootBeanDefinition( argResolvers.add(new RootBeanDefinition(
AuthenticationPrincipalArgumentResolver.class)); AuthenticationPrincipalArgumentResolver.class));
bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers); bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers);
if(!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) {
PropertyValue pathMatcherProp = bd.getPropertyValues().getPropertyValue("pathMatcher");
Object pathMatcher = pathMatcherProp == null ? null : pathMatcherProp.getValue();
if(pathMatcher instanceof BeanReference) {
registry.registerAlias(((BeanReference) pathMatcher).getBeanName(), PATH_MATCHER_BEAN_NAME);
}
}
} }
else if (beanClassName else if (beanClassName
.equals("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler")) { .equals("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler")) {
@ -270,6 +286,10 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
} }
inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors); inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors);
if(!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) {
registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, new RootBeanDefinition(AntPathMatcher.class));
}
} }
private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) { private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) {
@ -289,4 +309,41 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements
} }
} }
static class DelegatingPathMatcher implements PathMatcher {
private PathMatcher delegate = new AntPathMatcher();
public boolean isPattern(String path) {
return delegate.isPattern(path);
}
public boolean match(String pattern, String path) {
return delegate.match(pattern, path);
}
public boolean matchStart(String pattern, String path) {
return delegate.matchStart(pattern, path);
}
public String extractPathWithinPattern(String pattern, String path) {
return delegate.extractPathWithinPattern(pattern, path);
}
public Map<String, String> extractUriTemplateVariables(String pattern, String path) {
return delegate.extractUriTemplateVariables(pattern, path);
}
public Comparator<String> getPatternComparator(String path) {
return delegate.getPatternComparator(path);
}
public String combine(String pattern1, String pattern2) {
return delegate.combine(pattern1, pattern2);
}
void setPathMatcher(PathMatcher pathMatcher) {
this.delegate = pathMatcher;
}
}
} }

View File

@ -1,5 +1,7 @@
package org.springframework.security.config.websocket package org.springframework.security.config.websocket
import static org.mockito.Mockito.*
import org.springframework.beans.BeansException import org.springframework.beans.BeansException
import org.springframework.beans.factory.config.BeanDefinition import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory import org.springframework.beans.factory.config.ConfigurableListableBeanFactory
@ -11,21 +13,29 @@ import org.springframework.core.MethodParameter
import org.springframework.core.task.SyncTaskExecutor import org.springframework.core.task.SyncTaskExecutor
import org.springframework.http.server.ServerHttpRequest import org.springframework.http.server.ServerHttpRequest
import org.springframework.http.server.ServerHttpResponse import org.springframework.http.server.ServerHttpResponse
import org.springframework.messaging.Message
import org.springframework.messaging.MessageDeliveryException
import org.springframework.messaging.handler.annotation.MessageMapping import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver
import org.springframework.messaging.simp.SimpMessageHeaderAccessor
import org.springframework.messaging.simp.SimpMessageType import org.springframework.messaging.simp.SimpMessageType
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler
import org.springframework.messaging.support.ChannelInterceptor
import org.springframework.messaging.support.GenericMessage
import org.springframework.mock.web.MockHttpServletRequest import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpServletResponse import org.springframework.mock.web.MockHttpServletResponse
import org.springframework.security.access.AccessDeniedException
import org.springframework.security.authentication.TestingAuthenticationToken
import org.springframework.security.config.AbstractXmlConfigTests
import org.springframework.security.core.Authentication import org.springframework.security.core.Authentication
import org.springframework.security.core.annotation.AuthenticationPrincipal import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.web.csrf.CsrfToken import org.springframework.security.web.csrf.CsrfToken
import org.springframework.security.web.csrf.DefaultCsrfToken import org.springframework.security.web.csrf.DefaultCsrfToken
import org.springframework.security.web.csrf.InvalidCsrfTokenException import org.springframework.security.web.csrf.InvalidCsrfTokenException
import org.springframework.security.web.csrf.MissingCsrfTokenException
import org.springframework.stereotype.Controller import org.springframework.stereotype.Controller
import org.springframework.util.AntPathMatcher
import org.springframework.web.servlet.HandlerMapping import org.springframework.web.servlet.HandlerMapping
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping
import org.springframework.web.socket.WebSocketHandler import org.springframework.web.socket.WebSocketHandler
import org.springframework.web.socket.server.HandshakeFailureException import org.springframework.web.socket.server.HandshakeFailureException
import org.springframework.web.socket.server.HandshakeHandler import org.springframework.web.socket.server.HandshakeHandler
@ -33,459 +43,479 @@ import org.springframework.web.socket.server.support.HttpSessionHandshakeInterce
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler
import spock.lang.Unroll import spock.lang.Unroll
import static org.mockito.Mockito.*
import org.springframework.messaging.Message
import org.springframework.messaging.MessageDeliveryException
import org.springframework.messaging.simp.SimpMessageHeaderAccessor
import org.springframework.messaging.support.ChannelInterceptor
import org.springframework.messaging.support.GenericMessage
import org.springframework.security.access.AccessDeniedException
import org.springframework.security.authentication.TestingAuthenticationToken
import org.springframework.security.config.AbstractXmlConfigTests
import org.springframework.security.core.context.SecurityContextHolder
/** /**
* *
* @author Rob Winch * @author Rob Winch
*/ */
class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests { class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
boolean useSockJS = false boolean useSockJS = false
CsrfToken csrfToken = new DefaultCsrfToken('headerName', 'paramName', 'token') CsrfToken csrfToken = new DefaultCsrfToken('headerName', 'paramName', 'token')
def cleanup() { def cleanup() {
SecurityContextHolder.clearContext() SecurityContextHolder.clearContext()
} }
def 'websocket with no id automatically integrates with clientInboundChannel'() { def 'websocket with no id automatically integrates with clientInboundChannel'() {
setup: setup:
websocket { websocket {
'intercept-message'(pattern:'/permitAll',access:'permitAll') 'intercept-message'(pattern:'/permitAll',access:'permitAll')
'intercept-message'(pattern:'/denyAll',access:'denyAll') 'intercept-message'(pattern:'/denyAll',access:'denyAll')
} }
when: 'message is sent to the denyAll endpoint' when: 'message is sent to the denyAll endpoint'
clientInboundChannel.send(message('/denyAll')) clientInboundChannel.send(message('/denyAll'))
then: 'access is denied to the denyAll endpoint' then: 'access is denied to the denyAll endpoint'
def e = thrown(MessageDeliveryException) def e = thrown(MessageDeliveryException)
e.cause instanceof AccessDeniedException e.cause instanceof AccessDeniedException
and: 'access is granted to the permitAll endpoint' and: 'access is granted to the permitAll endpoint'
clientInboundChannel.send(message('/permitAll')) clientInboundChannel.send(message('/permitAll'))
} }
def 'anonymous authentication supported'() { def 'anonymous authentication supported'() {
setup: setup:
websocket { websocket {
'intercept-message'(pattern:'/permitAll',access:'permitAll') 'intercept-message'(pattern:'/permitAll',access:'permitAll')
'intercept-message'(pattern:'/denyAll',access:'denyAll') 'intercept-message'(pattern:'/denyAll',access:'denyAll')
} }
messageUser = null messageUser = null
when: 'message is sent to the permitAll endpoint with no user' when: 'message is sent to the permitAll endpoint with no user'
clientInboundChannel.send(message('/permitAll')) clientInboundChannel.send(message('/permitAll'))
then: 'access is granted' then: 'access is granted'
noExceptionThrown() noExceptionThrown()
} }
@Unroll @Unroll
def "message type - #type"(SimpMessageType type) { def "message type - #type"(SimpMessageType type) {
setup: setup:
websocket { websocket {
'intercept-message'('type': type.toString(), access:'permitAll') 'intercept-message'('type': type.toString(), access:'permitAll')
'intercept-message'(pattern:'/**', access:'denyAll') 'intercept-message'(pattern:'/**', access:'denyAll')
} }
messageUser = null messageUser = null
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type)
if(SimpMessageType.CONNECT == type) { if(SimpMessageType.CONNECT == type) {
headers.setNativeHeader(csrfToken.headerName, csrfToken.token) headers.setNativeHeader(csrfToken.headerName, csrfToken.token)
} }
Message message = message(headers, '/permitAll') Message message = message(headers, '/permitAll')
when: 'message is sent to the permitAll endpoint with no user' when: 'message is sent to the permitAll endpoint with no user'
clientInboundChannel.send(message) clientInboundChannel.send(message)
then: 'access is granted' then: 'access is granted'
noExceptionThrown() noExceptionThrown()
where: where:
type << SimpMessageType.values() type << SimpMessageType.values()
} }
@Unroll @Unroll
def "pattern and message type - #type"(SimpMessageType type) { def "pattern and message type - #type"(SimpMessageType type) {
setup: setup:
websocket { websocket {
'intercept-message'(pattern: '/permitAll', 'type': type.toString(), access:'permitAll') 'intercept-message'(pattern: '/permitAll', 'type': type.toString(), access:'permitAll')
'intercept-message'(pattern:'/**', access:'denyAll') 'intercept-message'(pattern:'/**', access:'denyAll')
} }
when: 'message is sent to the permitAll endpoint with no user' when: 'message is sent to the permitAll endpoint with no user'
clientInboundChannel.send(message('/permitAll', type)) clientInboundChannel.send(message('/permitAll', type))
then: 'access is granted' then: 'access is granted'
noExceptionThrown() noExceptionThrown()
when: 'message sent to other message type' when: 'message sent to other message type'
clientInboundChannel.send(message('/permitAll', SimpMessageType.UNSUBSCRIBE)) clientInboundChannel.send(message('/permitAll', SimpMessageType.UNSUBSCRIBE))
then: 'does not match' then: 'does not match'
MessageDeliveryException e = thrown() MessageDeliveryException e = thrown()
e.cause instanceof AccessDeniedException e.cause instanceof AccessDeniedException
when: 'message is sent to other pattern' when: 'message is sent to other pattern'
clientInboundChannel.send(message('/other', type)) clientInboundChannel.send(message('/other', type))
then: 'does not match' then: 'does not match'
MessageDeliveryException eOther = thrown() MessageDeliveryException eOther = thrown()
eOther.cause instanceof AccessDeniedException eOther.cause instanceof AccessDeniedException
where: where:
type << [SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE] type << [SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE]
} }
@Unroll @Unroll
def "intercept-message with invalid type and pattern - #type"(SimpMessageType type) { def "intercept-message with invalid type and pattern - #type"(SimpMessageType type) {
when: when:
websocket { websocket {
'intercept-message'(pattern : '/**', 'type': type.toString(), access:'permitAll') 'intercept-message'(pattern : '/**', 'type': type.toString(), access:'permitAll')
} }
then: then:
thrown(BeanDefinitionParsingException) thrown(BeanDefinitionParsingException)
where: where:
type << [SimpMessageType.CONNECT, SimpMessageType.CONNECT_ACK, SimpMessageType.DISCONNECT, SimpMessageType.DISCONNECT_ACK, SimpMessageType.HEARTBEAT, SimpMessageType.OTHER, SimpMessageType.UNSUBSCRIBE ] type << [SimpMessageType.CONNECT, SimpMessageType.CONNECT_ACK, SimpMessageType.DISCONNECT, SimpMessageType.DISCONNECT_ACK, SimpMessageType.HEARTBEAT, SimpMessageType.OTHER, SimpMessageType.UNSUBSCRIBE ]
} }
def 'messages with no id automatically adds Authentication argument resolver'() { def 'messages with no id automatically adds Authentication argument resolver'() {
setup: setup:
def id = 'authenticationController' def id = 'authenticationController'
bean(id,MyController) bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor) bean('inPostProcessor',InboundExecutorPostProcessor)
websocket { websocket {
'intercept-message'(pattern:'/**',access:'permitAll') 'intercept-message'(pattern:'/**',access:'permitAll')
} }
when: 'message is sent to the authentication endpoint' when: 'message is sent to the authentication endpoint'
clientInboundChannel.send(message('/authentication')) clientInboundChannel.send(message('/authentication'))
then: 'the AuthenticationPrincipal is resolved' then: 'the AuthenticationPrincipal is resolved'
def controller = appContext.getBean(id) def controller = appContext.getBean(id)
controller.authenticationPrincipal == messageUser.name controller.authenticationPrincipal == messageUser.name
} }
def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor'() { def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor'() {
setup: setup:
def id = 'authenticationController' def id = 'authenticationController'
bean(id,MyController) bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor) bean('inPostProcessor',InboundExecutorPostProcessor)
websocket { websocket {
'intercept-message'(pattern:'/**',access:'permitAll') 'intercept-message'(pattern:'/**',access:'permitAll')
} }
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication') Message<?> message = message(headers,'/authentication')
WebSocketHttpRequestHandler handler = appContext.getBean(WebSocketHttpRequestHandler) WebSocketHttpRequestHandler handler = appContext.getBean(WebSocketHttpRequestHandler)
MockHttpServletRequest request = new MockHttpServletRequest() MockHttpServletRequest request = new MockHttpServletRequest()
String sessionAttr = "sessionAttr" String sessionAttr = "sessionAttr"
request.getSession().setAttribute(sessionAttr,"sessionValue") request.getSession().setAttribute(sessionAttr,"sessionValue")
CsrfToken token = new DefaultCsrfToken("header", "param", "token") CsrfToken token = new DefaultCsrfToken("header", "param", "token")
request.setAttribute(CsrfToken.name, token) request.setAttribute(CsrfToken.name, token)
when: when:
handler.handleRequest(request , new MockHttpServletResponse()) handler.handleRequest(request , new MockHttpServletResponse())
TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler) TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
then: 'CsrfToken is populated' then: 'CsrfToken is populated'
handshakeHandler.attributes.get(CsrfToken.name) == token handshakeHandler.attributes.get(CsrfToken.name) == token
and: 'Explicitly listed HandshakeInterceptor are not overridden' and: 'Explicitly listed HandshakeInterceptor are not overridden'
handshakeHandler.attributes.get(sessionAttr) == request.getSession().getAttribute(sessionAttr) handshakeHandler.attributes.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
} }
def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor with SockJS'() { def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor with SockJS'() {
setup: setup:
useSockJS = true useSockJS = true
def id = 'authenticationController' def id = 'authenticationController'
bean(id,MyController) bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor) bean('inPostProcessor',InboundExecutorPostProcessor)
websocket { websocket {
'intercept-message'(pattern:'/**',access:'permitAll') 'intercept-message'(pattern:'/**',access:'permitAll')
} }
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication') Message<?> message = message(headers,'/authentication')
SockJsHttpRequestHandler handler = appContext.getBean(SockJsHttpRequestHandler) SockJsHttpRequestHandler handler = appContext.getBean(SockJsHttpRequestHandler)
MockHttpServletRequest request = new MockHttpServletRequest() MockHttpServletRequest request = new MockHttpServletRequest()
String sessionAttr = "sessionAttr" String sessionAttr = "sessionAttr"
request.getSession().setAttribute(sessionAttr,"sessionValue") request.getSession().setAttribute(sessionAttr,"sessionValue")
CsrfToken token = new DefaultCsrfToken("header", "param", "token") CsrfToken token = new DefaultCsrfToken("header", "param", "token")
request.setAttribute(CsrfToken.name, token) request.setAttribute(CsrfToken.name, token)
request.setMethod("GET") request.setMethod("GET")
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket") request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket")
when: when:
handler.handleRequest(request , new MockHttpServletResponse()) handler.handleRequest(request , new MockHttpServletResponse())
TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler) TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
then: 'CsrfToken is populated' then: 'CsrfToken is populated'
handshakeHandler.attributes?.get(CsrfToken.name) == token handshakeHandler.attributes?.get(CsrfToken.name) == token
and: 'Explicitly listed HandshakeInterceptor are not overridden' and: 'Explicitly listed HandshakeInterceptor are not overridden'
handshakeHandler.attributes?.get(sessionAttr) == request.getSession().getAttribute(sessionAttr) handshakeHandler.attributes?.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
} }
def 'messages of type CONNECT require valid CsrfToken'() { def 'messages of type CONNECT require valid CsrfToken'() {
setup: setup:
def id = 'authenticationController' def id = 'authenticationController'
bean(id,MyController) bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor) bean('inPostProcessor',InboundExecutorPostProcessor)
websocket { websocket {
'intercept-message'(pattern:'/**',access:'permitAll') 'intercept-message'(pattern:'/**',access:'permitAll')
} }
when: 'websocket of type CONNECTION is sent without CsrfTOken' when: 'websocket of type CONNECTION is sent without CsrfTOken'
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication') Message<?> message = message(headers,'/authentication')
clientInboundChannel.send(message) clientInboundChannel.send(message)
then: 'CSRF Protection blocks the Message' then: 'CSRF Protection blocks the Message'
MessageDeliveryException expected = thrown() MessageDeliveryException expected = thrown()
expected.cause instanceof InvalidCsrfTokenException expected.cause instanceof InvalidCsrfTokenException
} }
def 'messages of type CONNECT disabled valid CsrfToken'() { def 'messages of type CONNECT disabled valid CsrfToken'() {
setup: setup:
def id = 'authenticationController' def id = 'authenticationController'
bean(id,MyController) bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor) bean('inPostProcessor',InboundExecutorPostProcessor)
websocket('same-origin-disabled':true) { websocket('same-origin-disabled':true) {
'intercept-message'(pattern:'/**',access:'permitAll') 'intercept-message'(pattern:'/**',access:'permitAll')
} }
when: 'websocket of type CONNECTION is sent without CsrfTOken' when: 'websocket of type CONNECTION is sent without CsrfTOken'
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT) SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication') Message<?> message = message(headers,'/authentication')
clientInboundChannel.send(message) clientInboundChannel.send(message)
then: 'CSRF Protection blocks the Message' then: 'CSRF Protection blocks the Message'
noExceptionThrown() noExceptionThrown()
} }
def 'websocket with no id does not override customArgumentResolvers'() { def 'websocket with no id does not override customArgumentResolvers'() {
setup: setup:
def id = 'authenticationController' def id = 'authenticationController'
bean(id,MyController) bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor) bean('inPostProcessor',InboundExecutorPostProcessor)
bean('mcar', MyCustomArgumentResolver) bean('mcar', MyCustomArgumentResolver)
xml.'websocket:message-broker' { xml.'websocket:message-broker' {
'websocket:transport' {} 'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/app') { 'websocket:stomp-endpoint'(path:'/app') {
'websocket:handshake-handler'(ref:'testHandler') {} 'websocket:handshake-handler'(ref:'testHandler') {}
} }
'websocket:simple-broker'(prefix:"/queue, /topic"){} 'websocket:simple-broker'(prefix:"/queue, /topic"){}
'websocket:argument-resolvers' { 'websocket:argument-resolvers' {
'b:ref'(bean:'mcar') 'b:ref'(bean:'mcar')
} }
} }
websocket { websocket {
'intercept-message'(pattern:'/**',access:'permitAll') 'intercept-message'(pattern:'/**',access:'permitAll')
} }
when: 'websocket is sent to the myCustom endpoint' when: 'websocket is sent to the myCustom endpoint'
clientInboundChannel.send(message('/myCustom')) clientInboundChannel.send(message('/myCustom'))
then: 'myCustomArgument is resolved' then: 'myCustomArgument is resolved'
def controller = appContext.getBean(id) def controller = appContext.getBean(id)
controller.myCustomArgument!= null controller.myCustomArgument!= null
} }
def 'websocket with id does not integrate with clientInboundChannel'() { def 'websocket defaults pathMatcher'() {
setup: setup:
websocket([id:'inCsi']) { bean('pathMatcher',AntPathMatcher.name,['.'])
'intercept-message'(pattern:'/**',access:'denyAll') bean('testHandler', TestHandshakeHandler)
} xml.'websocket:message-broker'('path-matcher':'pathMatcher') {
'websocket:transport' {}
when: 'websocket:stomp-endpoint'(path:'/app') {
def success = clientInboundChannel.send(message('/denyAll')) 'websocket:handshake-handler'(ref:'testHandler') {}
}
then: 'websocket:simple-broker'(prefix:"/queue, /topic"){}
success }
xml.'websocket-message-broker' {
} 'intercept-message'(pattern:'/denyAll.*',access:'denyAll')
}
def 'websocket with id can be explicitly integrated with clientInboundChannel'() { createAppContext()
setup: 'websocket security explicitly setup'
xml.'websocket:message-broker' { when: 'sent to denyAll.a'
'websocket:transport' {} appContext.getBean(SimpAnnotationMethodMessageHandler)
'websocket:stomp-endpoint'(path:'/app') { clientInboundChannel.send(message('/denyAll.a'))
'websocket:sockjs' {}
} then: 'access is denied'
'websocket:simple-broker'(prefix:"/queue, /topic"){} MessageDeliveryException expected = thrown()
'websocket:client-inbound-channel' { expected.cause instanceof AccessDeniedException
'websocket:interceptors' {
'b:bean'(class:'org.springframework.security.messaging.context.SecurityContextChannelInterceptor'){} when: 'sent to denyAll.a.b'
'b:ref'(bean:'inCsi'){} clientInboundChannel.send(message('/denyAll.a.b'))
}
} then: 'access is allowed'
} noExceptionThrown()
xml.'websocket-message-broker'(id:'inCsi') { }
'intercept-message'(pattern:'/**',access:'denyAll')
} def 'websocket with id does not integrate with clientInboundChannel'() {
createAppContext() setup:
websocket([id:'inCsi']) {
when: 'intercept-message'(pattern:'/**',access:'denyAll')
clientInboundChannel.send(message('/denyAll')) }
then: when:
def e = thrown(MessageDeliveryException) def success = clientInboundChannel.send(message('/denyAll'))
e.cause instanceof AccessDeniedException
then:
} success
def 'automatic integration with clientInboundChannel does not override exisiting websocket:interceptors'() { }
setup:
mockBean(ChannelInterceptor,'mci') def 'websocket with id can be explicitly integrated with clientInboundChannel'() {
xml.'websocket:message-broker'('application-destination-prefix':'/app', setup: 'websocket security explicitly setup'
'user-destination-prefix':'/user') { xml.'websocket:message-broker' {
'websocket:transport' {} 'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/foo') { 'websocket:stomp-endpoint'(path:'/app') {
'websocket:sockjs' {} 'websocket:sockjs' {}
} }
'websocket:simple-broker'(prefix:"/queue, /topic"){} 'websocket:simple-broker'(prefix:"/queue, /topic"){}
'websocket:client-inbound-channel' { 'websocket:client-inbound-channel' {
'websocket:interceptors' { 'websocket:interceptors' {
'b:ref'(bean:'mci'){} 'b:bean'(class:'org.springframework.security.messaging.context.SecurityContextChannelInterceptor'){}
} 'b:ref'(bean:'inCsi'){}
} }
} }
xml.'websocket-message-broker' { }
'intercept-message'(pattern:'/denyAll',access:'denyAll') xml.'websocket-message-broker'(id:'inCsi') {
'intercept-message'(pattern:'/permitAll',access:'permitAll') 'intercept-message'(pattern:'/**',access:'denyAll')
} }
createAppContext() createAppContext()
ChannelInterceptor mci = appContext.getBean('mci')
when: when:
Message<?> message = message('/permitAll') clientInboundChannel.send(message('/denyAll'))
clientInboundChannel.send(message)
then:
then: def e = thrown(MessageDeliveryException)
verify(mci).preSend(message, clientInboundChannel) || true e.cause instanceof AccessDeniedException
} }
def websocket(Map<String,Object> attrs=[:], Closure c) { def 'automatic integration with clientInboundChannel does not override exisiting websocket:interceptors'() {
bean('testHandler', TestHandshakeHandler) setup:
xml.'websocket:message-broker' { mockBean(ChannelInterceptor,'mci')
'websocket:transport' {} xml.'websocket:message-broker'('application-destination-prefix':'/app',
'websocket:stomp-endpoint'(path:'/app') { 'user-destination-prefix':'/user') {
'websocket:handshake-handler'(ref:'testHandler') {} 'websocket:transport' {}
'websocket:handshake-interceptors' { 'websocket:stomp-endpoint'(path:'/foo') {
'b:bean'('class':HttpSessionHandshakeInterceptor.name) {} 'websocket:sockjs' {}
} }
if(useSockJS) { 'websocket:simple-broker'(prefix:"/queue, /topic"){}
'websocket:sockjs' {} 'websocket:client-inbound-channel' {
} 'websocket:interceptors' {
} 'b:ref'(bean:'mci'){}
'websocket:simple-broker'(prefix:"/queue, /topic"){} }
} }
xml.'websocket-message-broker'(attrs, c) }
createAppContext() xml.'websocket-message-broker' {
} 'intercept-message'(pattern:'/denyAll',access:'denyAll')
'intercept-message'(pattern:'/permitAll',access:'permitAll')
def getClientInboundChannel() { }
appContext.getBean("clientInboundChannel") createAppContext()
} ChannelInterceptor mci = appContext.getBean('mci')
when:
def message(String destination, SimpMessageType type=SimpMessageType.MESSAGE) { Message<?> message = message('/permitAll')
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type) clientInboundChannel.send(message)
message(headers, destination)
} then:
verify(mci).preSend(message, clientInboundChannel) || true
def message(SimpMessageHeaderAccessor headers, String destination) {
messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') }
headers.sessionId = '123'
headers.sessionAttributes = [:] def websocket(Map<String,Object> attrs=[:], Closure c) {
headers.destination = destination bean('testHandler', TestHandshakeHandler)
if(messageUser != null) { xml.'websocket:message-broker' {
headers.user = messageUser 'websocket:transport' {}
} 'websocket:stomp-endpoint'(path:'/app') {
if(csrfToken != null) { 'websocket:handshake-handler'(ref:'testHandler') {}
headers.sessionAttributes[CsrfToken.name] = csrfToken 'websocket:handshake-interceptors' {
} 'b:bean'('class':HttpSessionHandshakeInterceptor.name) {}
new GenericMessage<String>("hi",headers.messageHeaders) }
} if(useSockJS) {
'websocket:sockjs' {}
@Controller }
static class MyController { }
String authenticationPrincipal 'websocket:simple-broker'(prefix:"/queue, /topic"){}
MyCustomArgument myCustomArgument }
xml.'websocket-message-broker'(attrs, c)
@MessageMapping('/authentication') createAppContext()
public void authentication(@AuthenticationPrincipal String un) { }
this.authenticationPrincipal = un
} def getClientInboundChannel() {
appContext.getBean("clientInboundChannel")
@MessageMapping('/myCustom') }
public void myCustom(MyCustomArgument myCustomArgument) {
this.myCustomArgument = myCustomArgument def message(String destination, SimpMessageType type=SimpMessageType.MESSAGE) {
} SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type)
} message(headers, destination)
}
static class MyCustomArgument {
MyCustomArgument(String notDefaultConstr) {} def message(SimpMessageHeaderAccessor headers, String destination) {
} messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
headers.sessionId = '123'
static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver { headers.sessionAttributes = [:]
headers.destination = destination
@Override if(messageUser != null) {
boolean supportsParameter(MethodParameter parameter) { headers.user = messageUser
parameter.parameterType.isAssignableFrom(MyCustomArgument) }
} if(csrfToken != null) {
headers.sessionAttributes[CsrfToken.name] = csrfToken
@Override }
Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception { new GenericMessage<String>("hi",headers.messageHeaders)
new MyCustomArgument("") }
}
} @Controller
static class MyController {
static class TestHandshakeHandler implements HandshakeHandler { String authenticationPrincipal
Map<String, Object> attributes; MyCustomArgument myCustomArgument
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException { @MessageMapping('/authentication')
this.attributes = attributes public void authentication(@AuthenticationPrincipal String un) {
if(wsHandler instanceof SockJsWebSocketHandler) { this.authenticationPrincipal = un
// work around SPR-12716 }
SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
this.attributes = sockJs.sockJsSession.attributes @MessageMapping('/myCustom')
} public void myCustom(MyCustomArgument myCustomArgument) {
true this.myCustomArgument = myCustomArgument
} }
} }
/** static class MyCustomArgument {
* Changes the clientInboundChannel Executor to be synchronous MyCustomArgument(String notDefaultConstr) {}
*/ }
static class InboundExecutorPostProcessor implements BeanDefinitionRegistryPostProcessor {
static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver {
@Override
void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { @Override
BeanDefinition inbound = registry.getBeanDefinition("clientInboundChannel") boolean supportsParameter(MethodParameter parameter) {
inbound.getConstructorArgumentValues().addIndexedArgumentValue(0, new RootBeanDefinition(SyncTaskExecutor)); parameter.parameterType.isAssignableFrom(MyCustomArgument)
} }
@Override @Override
void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
new MyCustomArgument("")
} }
} }
static class TestHandshakeHandler implements HandshakeHandler {
Map<String, Object> attributes;
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
this.attributes = attributes
if(wsHandler instanceof SockJsWebSocketHandler) {
// work around SPR-12716
SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
this.attributes = sockJs.sockJsSession.attributes
}
true
}
}
/**
* Changes the clientInboundChannel Executor to be synchronous
*/
static class InboundExecutorPostProcessor implements BeanDefinitionRegistryPostProcessor {
@Override
void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
BeanDefinition inbound = registry.getBeanDefinition("clientInboundChannel")
inbound.getConstructorArgumentValues().addIndexedArgumentValue(0, new RootBeanDefinition(SyncTaskExecutor));
}
@Override
void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
}
}
} }