SEC-2702: Add WebSocket Security XML Namespace Support

This commit is contained in:
Rob Winch 2014-11-25 09:45:32 -06:00
parent 09f6210c73
commit 8ad16b01f5
8 changed files with 541 additions and 0 deletions

View File

@ -56,4 +56,7 @@ public abstract class Elements {
public static final String HTTP_FIREWALL = "http-firewall";
public static final String HEADERS = "headers";
public static final String CSRF = "csrf";
public static final String MESSAGES = "messages";
public static final String INTERCEPT_MESSAGE = "message-interceptor";
}

View File

@ -38,6 +38,7 @@ import org.springframework.security.config.http.HttpSecurityBeanDefinitionParser
import org.springframework.security.config.ldap.LdapProviderBeanDefinitionParser;
import org.springframework.security.config.ldap.LdapServerBeanDefinitionParser;
import org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser;
import org.springframework.security.config.message.MessageSecurityBeanDefinitionParser;
import org.springframework.security.config.method.GlobalMethodSecurityBeanDefinitionParser;
import org.springframework.security.config.method.InterceptMethodsBeanDefinitionDecorator;
import org.springframework.security.config.method.MethodSecurityMetadataSourceBeanDefinitionParser;
@ -56,6 +57,7 @@ import org.w3c.dom.Node;
*/
public final class SecurityNamespaceHandler implements NamespaceHandler {
private static final String FILTER_CHAIN_PROXY_CLASSNAME = "org.springframework.security.web.FilterChainProxy";
private static final String MESSAGE_CLASSNAME = "org.springframework.messaging.Message";
private final Log logger = LogFactory.getLog(getClass());
private final Map<String, BeanDefinitionParser> parsers = new HashMap<String, BeanDefinitionParser>();
private final BeanDefinitionDecorator interceptMethodsBDD = new InterceptMethodsBeanDefinitionDecorator();
@ -176,6 +178,10 @@ public final class SecurityNamespaceHandler implements NamespaceHandler {
parsers.put(Elements.FILTER_CHAIN, new FilterChainBeanDefinitionParser());
filterChainMapBDD = new FilterChainMapBeanDefinitionDecorator();
}
if(ClassUtils.isPresent(MESSAGE_CLASSNAME, getClass().getClassLoader())) {
parsers.put(Elements.MESSAGES, new MessageSecurityBeanDefinitionParser());
}
}
/**

View File

@ -0,0 +1,189 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.message;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.PropertyValue;
import org.springframework.beans.factory.config.*;
import org.springframework.beans.factory.support.*;
import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.beans.factory.xml.XmlReaderContext;
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
import org.springframework.security.access.vote.ConsensusBased;
import org.springframework.security.config.Elements;
import org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory;
import org.springframework.security.messaging.access.expression.MessageExpressionVoter;
import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element;
import java.util.List;
/**
* Parses Spring Security's message namespace support. A simple example is:
*
* <code>
* &lt;messages&gt;
* &lt;message-interceptor pattern='/permitAll' access='permitAll' /&gt;
* &lt;message-interceptor pattern='/denyAll' access='denyAll' /&gt;
* &lt;/messages&gt;
* </code>
*
* <p>
* The above configuration will ensure that any SimpAnnotationMethodMessageHandler has the AuthenticationPrincipalArgumentResolver
* registered as a custom argument resolver. It also ensures that the SecurityContextChannelInterceptor is automatically
* registered for the clientInboundChannel. Last, it ensures that a ChannelSecurityInterceptor is registered with the
* clientInboundChannel.
* </p>
*
* <p>
* If finer control is necessary, the id attribute can be used as shown below:
* </p>
*
* <code>
* &lt;messages id="channelSecurityInterceptor"&gt;
* &lt;message-interceptor pattern='/permitAll' access='permitAll' /&gt;
* &lt;message-interceptor pattern='/denyAll' access='denyAll' /&gt;
* &lt;/messages&gt;
* </code>
*
* <p>
* Now the configuration will only create a bean named ChannelSecurityInterceptor and assign it to the id of
* channelSecurityInterceptor. Users can explicitly wire Spring Security using the standard Spring Messaging XML
* namespace support.
* </p>
*
* @author Rob Winch
* @since 4.0
*/
public final class MessageSecurityBeanDefinitionParser implements BeanDefinitionParser {
private static final Log logger = LogFactory.getLog(MessageSecurityBeanDefinitionParser.class);
private static final String ID_ATTR = "id";
private static final String PATTERN_ATTR = "pattern";
private static final String ACCESS_ATTR = "access";
/**
* @param element
* @param parserContext
* @return
*/
public BeanDefinition parse(Element element, ParserContext parserContext) {
BeanDefinitionRegistry registry = parserContext.getRegistry();
XmlReaderContext context = parserContext.getReaderContext();
ManagedMap<BeanDefinition,String> matcherToExpression = new ManagedMap<BeanDefinition, String>();
String id = element.getAttribute(ID_ATTR);
List<Element> interceptMessages = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_MESSAGE);
for(Element interceptMessage : interceptMessages) {
String matcherPattern = interceptMessage.getAttribute(PATTERN_ATTR);
String accessExpression = interceptMessage.getAttribute(ACCESS_ATTR);
BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpDestinationMessageMatcher.class);
matcher.addConstructorArgValue(matcherPattern);
matcherToExpression.put(matcher.getBeanDefinition(), accessExpression);
}
BeanDefinitionBuilder mds = BeanDefinitionBuilder.rootBeanDefinition(ExpressionBasedMessageSecurityMetadataSourceFactory.class);
mds.setFactoryMethod("createExpressionMessageMetadataSource");
mds.addConstructorArgValue(matcherToExpression);
String mdsId = context.registerWithGeneratedName(mds.getBeanDefinition());
ManagedList<BeanDefinition> voters = new ManagedList<BeanDefinition>();
voters.add(new RootBeanDefinition(MessageExpressionVoter.class));
BeanDefinitionBuilder adm = BeanDefinitionBuilder.rootBeanDefinition(ConsensusBased.class);
adm.addConstructorArgValue(voters);
BeanDefinitionBuilder inboundChannelSecurityInterceptor = BeanDefinitionBuilder.rootBeanDefinition(ChannelSecurityInterceptor.class);
inboundChannelSecurityInterceptor.addConstructorArgValue(registry.getBeanDefinition(mdsId));
inboundChannelSecurityInterceptor.addPropertyValue("accessDecisionManager", adm.getBeanDefinition());
String inSecurityInterceptorName = context.registerWithGeneratedName(inboundChannelSecurityInterceptor.getBeanDefinition());
if(StringUtils.hasText(id)) {
registry.registerAlias(inSecurityInterceptorName, id);
} else {
BeanDefinitionBuilder mspp = BeanDefinitionBuilder.rootBeanDefinition(MessageSecurityPostProcessor.class);
mspp.addConstructorArgValue(inSecurityInterceptorName);
context.registerWithGeneratedName(mspp.getBeanDefinition());
}
return null;
}
static class MessageSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor {
private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel";
private static final String INTERCEPTORS_PROP = "interceptors";
private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers";
private final String inboundSecurityInterceptorId;
public MessageSecurityPostProcessor(String inboundSecurityInterceptorId) {
this.inboundSecurityInterceptorId = inboundSecurityInterceptorId;
}
@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
String[] beanNames = registry.getBeanDefinitionNames();
for(String beanName : beanNames) {
BeanDefinition bd = registry.getBeanDefinition(beanName);
if(bd.getBeanClassName().equals(SimpAnnotationMethodMessageHandler.class.getName())) {
PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP);
ManagedList<Object> argResolvers = new ManagedList<Object>();
if(current != null) {
argResolvers.addAll((ManagedList<?>)current.getValue());
}
argResolvers.add(new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class));
bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers);
}
}
if(!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) {
return;
}
ManagedList<Object> interceptors = new ManagedList();
interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class));
interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId));
BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID);
PropertyValue currentInterceptorsPv = inboundChannel.getPropertyValues().getPropertyValue(INTERCEPTORS_PROP);
if(currentInterceptorsPv != null) {
ManagedList<?> currentInterceptors = (ManagedList<?>) currentInterceptorsPv.getValue();
interceptors.addAll(currentInterceptors);
}
inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors);
}
@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
}
}
}

View File

@ -275,6 +275,25 @@ protect-pointcut.attlist &=
## Access configuration attributes list that applies to all methods matching the pointcut, e.g. "ROLE_A,ROLE_B"
attribute access {xsd:token}
messages =
## Messages
element messages { messages.attrlist, (message-interceptor*) }
messages.attrlist &=
## the id
attribute id {xsd:token}?
message-interceptor =
## Message
element message-interceptor {message-interceptor.attrlist}
message-interceptor.attrlist &=
## pattern
attribute pattern {xsd:token}?
message-interceptor.attrlist &=
## access
attribute access {xsd:token}?
http-firewall =
## Allows a custom instance of HttpFirewall to be injected into the FilterChainProxy created by the namespace.
element http-firewall {ref}

View File

@ -853,6 +853,49 @@
</xs:annotation>
</xs:attribute>
</xs:attributeGroup>
<xs:element name="messages">
<xs:annotation>
<xs:documentation>Messages
</xs:documentation>
</xs:annotation>
<xs:complexType>
<xs:sequence>
<xs:element minOccurs="0" maxOccurs="unbounded" ref="security:message-interceptor"/>
</xs:sequence>
<xs:attributeGroup ref="security:messages.attrlist"/>
</xs:complexType>
</xs:element>
<xs:attributeGroup name="messages.attrlist">
<xs:attribute name="id" type="xs:token">
<xs:annotation>
<xs:documentation>the id
</xs:documentation>
</xs:annotation>
</xs:attribute>
</xs:attributeGroup>
<xs:element name="message-interceptor">
<xs:annotation>
<xs:documentation>Message
</xs:documentation>
</xs:annotation>
<xs:complexType>
<xs:attributeGroup ref="security:message-interceptor.attrlist"/>
</xs:complexType>
</xs:element>
<xs:attributeGroup name="message-interceptor.attrlist">
<xs:attribute name="pattern" type="xs:token">
<xs:annotation>
<xs:documentation>pattern
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="access" type="xs:token">
<xs:annotation>
<xs:documentation>access
</xs:documentation>
</xs:annotation>
</xs:attribute>
</xs:attributeGroup>
<xs:element name="http-firewall">
<xs:annotation>
<xs:documentation>Allows a custom instance of HttpFirewall to be injected into the FilterChainProxy created

View File

@ -0,0 +1,266 @@
package org.springframework.security.config.message
import org.springframework.beans.BeansException
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory
import org.springframework.beans.factory.support.BeanDefinitionRegistry
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor
import org.springframework.beans.factory.support.RootBeanDefinition
import org.springframework.core.MethodParameter
import org.springframework.core.task.SyncTaskExecutor
import org.springframework.http.server.ServerHttpRequest
import org.springframework.http.server.ServerHttpResponse
import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver
import org.springframework.security.core.Authentication
import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.stereotype.Controller
import org.springframework.web.socket.WebSocketHandler
import org.springframework.web.socket.server.HandshakeFailureException
import org.springframework.web.socket.server.HandshakeHandler
import static org.mockito.Mockito.*
import org.springframework.messaging.Message
import org.springframework.messaging.MessageDeliveryException
import org.springframework.messaging.simp.SimpMessageHeaderAccessor
import org.springframework.messaging.support.ChannelInterceptor
import org.springframework.messaging.support.GenericMessage
import org.springframework.security.access.AccessDeniedException
import org.springframework.security.authentication.TestingAuthenticationToken
import org.springframework.security.config.AbstractXmlConfigTests
import org.springframework.security.core.context.SecurityContextHolder
/**
*
* @author Rob Winch
*/
class MessagesConfigTests extends AbstractXmlConfigTests {
Authentication messageUser
def cleanup() {
SecurityContextHolder.clearContext()
}
def 'messages with no id automatically integrates with clientInboundChannel'() {
setup:
messages {
'message-interceptor'(pattern:'/permitAll',access:'permitAll')
'message-interceptor'(pattern:'/denyAll',access:'denyAll')
}
when: 'message is sent to the denyAll endpoint'
clientInboundChannel.send(message('/denyAll'))
then: 'access is denied to the denyAll endpoint'
def e = thrown(MessageDeliveryException)
e.cause instanceof AccessDeniedException
and: 'access is granted to the permitAll endpoint'
clientInboundChannel.send(message('/permitAll'))
}
def 'messages with no id automatically adds Authentication argument resolver'() {
setup:
def id = 'authenticationController'
bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor)
messages {
'message-interceptor'(pattern:'/**',access:'permitAll')
}
when: 'message is sent to the authentication endpoint'
clientInboundChannel.send(message('/authentication'))
then: 'the AuthenticationPrincipal is resolved'
def controller = appContext.getBean(id)
controller.authenticationPrincipal == messageUser.name
}
def 'messages with no id does not override customArgumentResolvers'() {
setup:
def id = 'authenticationController'
bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor)
bean('mcar', MyCustomArgumentResolver)
xml.'websocket:message-broker' {
'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/app') {
'websocket:handshake-handler'(ref:'testHandler') {}
}
'websocket:simple-broker'(prefix:"/queue, /topic"){}
'websocket:argument-resolvers' {
'b:ref'(bean:'mcar')
}
}
messages {
'message-interceptor'(pattern:'/**',access:'permitAll')
}
when: 'message is sent to the myCustom endpoint'
clientInboundChannel.send(message('/myCustom'))
then: 'myCustomArgument is resolved'
def controller = appContext.getBean(id)
controller.myCustomArgument!= null
}
def 'messages with id does not integrate with clientInboundChannel'() {
setup:
messages([id:'inCsi']) {
'message-interceptor'(pattern:'/**',access:'denyAll')
}
when:
def success = clientInboundChannel.send(message('/denyAll'))
then:
success
}
def 'messages with id can be explicitly integrated with clientInboundChannel'() {
setup: 'message security explicitly setup'
xml.'websocket:message-broker' {
'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/app') {
'websocket:sockjs' {}
}
'websocket:simple-broker'(prefix:"/queue, /topic"){}
'websocket:client-inbound-channel' {
'websocket:interceptors' {
'b:bean'(class:'org.springframework.security.messaging.context.SecurityContextChannelInterceptor'){}
'b:ref'(bean:'inCsi'){}
}
}
}
xml.messages(id:'inCsi') {
'message-interceptor'(pattern:'/**',access:'denyAll')
}
createAppContext()
when:
clientInboundChannel.send(message('/denyAll'))
then:
def e = thrown(MessageDeliveryException)
e.cause instanceof AccessDeniedException
}
def 'automatic integration with clientInboundChannel does not override exisiting websocket:interceptors'() {
setup:
mockBean(ChannelInterceptor,'mci')
xml.'websocket:message-broker'('application-destination-prefix':'/app',
'user-destination-prefix':'/user') {
'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/foo') {
'websocket:sockjs' {}
}
'websocket:simple-broker'(prefix:"/queue, /topic"){}
'websocket:client-inbound-channel' {
'websocket:interceptors' {
'b:ref'(bean:'mci'){}
}
}
}
xml.messages {
'message-interceptor'(pattern:'/denyAll',access:'denyAll')
'message-interceptor'(pattern:'/permitAll',access:'permitAll')
}
createAppContext()
ChannelInterceptor mci = appContext.getBean('mci')
when:
Message<?> message = message('/permitAll')
clientInboundChannel.send(message)
then:
verify(mci).preSend(message, clientInboundChannel) || true
}
def messages(Map<String,Object> attrs=[:], Closure c) {
bean('testHandler', TestHandshakeHandler)
xml.'websocket:message-broker' {
'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/app') {
'websocket:handshake-handler'(ref:'testHandler') {}
}
'websocket:simple-broker'(prefix:"/queue, /topic"){}
}
xml.messages(attrs, c)
createAppContext()
}
def getClientInboundChannel() {
appContext.getBean("clientInboundChannel")
}
def message(String destination) {
messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
headers.sessionId = '123'
headers.sessionAttributes = [:]
headers.destination = destination
headers.user = messageUser
new GenericMessage<String>("hi",headers.messageHeaders)
}
@Controller
static class MyController {
String authenticationPrincipal
MyCustomArgument myCustomArgument
@MessageMapping('/authentication')
public void authentication(@AuthenticationPrincipal String un) {
this.authenticationPrincipal = un
}
@MessageMapping('/myCustom')
public void myCustom(MyCustomArgument myCustomArgument) {
this.myCustomArgument = myCustomArgument
}
}
static class MyCustomArgument {
MyCustomArgument(String notDefaultConstr) {}
}
static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver {
@Override
boolean supportsParameter(MethodParameter parameter) {
parameter.parameterType.isAssignableFrom(MyCustomArgument)
}
@Override
Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
new MyCustomArgument("")
}
}
static class TestHandshakeHandler implements HandshakeHandler {
@Override
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
true
}
}
/**
* Changes the clientInboundChannel Executor to be synchronous
*/
static class InboundExecutorPostProcessor implements BeanDefinitionRegistryPostProcessor {
@Override
void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
BeanDefinition inbound = registry.getBeanDefinition("clientInboundChannel")
inbound.getConstructorArgumentValues().addIndexedArgumentValue(0, new RootBeanDefinition(SyncTaskExecutor));
}
@Override
void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
}
}
}

View File

@ -13,6 +13,7 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.powermock.reflect.internal.WhiteboxImpl;
import org.springframework.beans.factory.parsing.BeanDefinitionParsingException;
import org.springframework.messaging.Message;
import org.springframework.security.config.util.InMemoryXmlApplicationContext;
import org.springframework.util.ClassUtils;
@ -122,4 +123,16 @@ public class SecurityNamespaceHandlerTests {
XML_AUTHENTICATION_MANAGER);
// should load just fine since no http block
}
@Test
public void messageNotFoundExceptionNoMessageBlock() throws Exception {
String className = FILTER_CHAIN_PROXY_CLASSNAME;
spy(ClassUtils.class);
doThrow(new ClassNotFoundException(className)).when(ClassUtils.class,"forName",eq(Message.class.getName()),any(ClassLoader.class));
new InMemoryXmlApplicationContext(
XML_AUTHENTICATION_MANAGER);
// should load just fine since no message block
}
}

View File

@ -30,9 +30,11 @@ public class InMemoryXmlApplicationContext extends AbstractXmlApplicationContext
" xmlns:context='http://www.springframework.org/schema/context'\n" +
" xmlns:b='http://www.springframework.org/schema/beans'\n" +
" xmlns:aop='http://www.springframework.org/schema/aop'\n" +
" xmlns:websocket='http://www.springframework.org/schema/websocket'\n" +
" xmlns:xsi='http://www.w3.org/2001/XMLSchema-instance'\n" +
" xsi:schemaLocation='http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans-2.5.xsd\n" +
"http://www.springframework.org/schema/aop http://www.springframework.org/schema/aop/spring-aop-2.5.xsd\n" +
"http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd\n" +
"http://www.springframework.org/schema/context http://www.springframework.org/schema/context/spring-context-2.5.xsd\n" +
"http://www.springframework.org/schema/security http://www.springframework.org/schema/security/spring-security-";
private static final String BEANS_CLOSE = "</b:beans>\n";