From 4a9fa0337a5eb5dc2603afa45fd513fdf001aaea Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Fri, 1 May 2020 10:50:45 -0500 Subject: [PATCH] Allow Configure RequestRjectedHandler in XML Issue gh-5007 --- .../HttpSecurityBeanDefinitionParser.java | 46 ++++++++++++++++++- .../config/http/MiscHttpConfigTests.java | 17 +++++++ ...HttpConfigTests-RequestRejectedHandler.xml | 32 +++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-RequestRejectedHandler.xml diff --git a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java index 256cae6dcb..d9f4a74ee2 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java @@ -24,14 +24,19 @@ import org.apache.commons.logging.LogFactory; import org.w3c.dom.Element; import org.springframework.beans.BeanMetadataElement; +import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ListFactoryBean; import org.springframework.beans.factory.config.MethodInvokingFactoryBean; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.parsing.BeanComponentDefinition; import org.springframework.beans.factory.parsing.CompositeComponentDefinition; +import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.beans.factory.support.ManagedList; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.BeanDefinitionParser; @@ -393,7 +398,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { } static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) { - if (pc.getRegistry().containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) { + BeanDefinitionRegistry registry = pc.getRegistry(); + if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) { return; } // Not already registered, so register the list of filter chains and the @@ -412,12 +418,48 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { BeanDefinition fcpBean = fcpBldr.getBeanDefinition(); pc.registerBeanComponent(new BeanComponentDefinition(fcpBean, BeanIds.FILTER_CHAIN_PROXY)); - pc.getRegistry().registerAlias(BeanIds.FILTER_CHAIN_PROXY, + registry.registerAlias(BeanIds.FILTER_CHAIN_PROXY, BeanIds.SPRING_SECURITY_FILTER_CHAIN); + + BeanDefinitionBuilder requestRejected = BeanDefinitionBuilder.rootBeanDefinition(RequestRejectedHandlerPostProcessor.class); + requestRejected.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + requestRejected.addConstructorArgValue("requestRejectedHandler"); + requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY); + requestRejected.addConstructorArgValue("requestRejectedHandler"); + AbstractBeanDefinition requestRejectedBean = requestRejected.getBeanDefinition(); + String requestRejectedPostProcessorName = pc.getReaderContext().generateBeanName(requestRejectedBean); + registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean); } } +class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor { + private final String beanName; + + private final String targetBeanName; + + private final String targetPropertyName; + + RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) { + this.beanName = beanName; + this.targetBeanName = targetBeanName; + this.targetPropertyName = targetPropertyName; + } + + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (registry.containsBeanDefinition(this.beanName)) { + BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName); + beanDefinition.getPropertyValues().add(this.targetPropertyName, new RuntimeBeanReference(this.beanName)); + } + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + + } +} + class OrderDecorator implements Ordered { final BeanMetadataElement bean; final int order; diff --git a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java index c4903195ab..4049067b0f 100644 --- a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java @@ -94,6 +94,8 @@ import org.springframework.security.web.context.request.async.WebAsyncManagerInt import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.firewall.FirewalledRequest; import org.springframework.security.web.firewall.HttpFirewall; +import org.springframework.security.web.firewall.RequestRejectedException; +import org.springframework.security.web.firewall.RequestRejectedHandler; import org.springframework.security.web.header.HeaderWriterFilter; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; @@ -754,6 +756,21 @@ public class MiscHttpConfigTests { verify(firewall).getFirewalledResponse(any(HttpServletResponse.class)); } + @Test + public void getWhenUsingCustomRequestRejectedHandlerThenRequestRejectedHandlerIsInvoked() throws Exception { + this.spring.configLocations(xml("RequestRejectedHandler")).autowire(); + + HttpServletResponse response = new MockHttpServletResponse(); + + RequestRejectedException rejected = new RequestRejectedException("failed"); + HttpFirewall firewall = this.spring.getContext().getBean(HttpFirewall.class); + RequestRejectedHandler requestRejectedHandler = this.spring.getContext().getBean(RequestRejectedHandler.class); + when(firewall.getFirewalledRequest(any(HttpServletRequest.class))).thenThrow(rejected); + this.mvc.perform(get("/unprotected")); + + verify(requestRejectedHandler).handle(any(), any(), any()); + } + @Test public void getWhenUsingCustomAccessDecisionManagerThenAuthorizesAccordingly() throws Exception { this.spring.configLocations(xml("CustomAccessDecisionManager")).autowire(); diff --git a/config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-RequestRejectedHandler.xml b/config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-RequestRejectedHandler.xml new file mode 100644 index 0000000000..be62e9a47c --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-RequestRejectedHandler.xml @@ -0,0 +1,32 @@ + + + + + + + + + + +