diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java index c86f054c7e..fad2eab83a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java @@ -57,6 +57,7 @@ import org.springframework.security.web.access.intercept.AuthorizationFilter; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; import org.springframework.security.web.debug.DebugFilter; import org.springframework.security.web.firewall.HttpFirewall; +import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler; import org.springframework.security.web.firewall.RequestRejectedHandler; import org.springframework.security.web.firewall.StrictHttpFirewall; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -307,6 +308,10 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder filterChains = (List) listFactoryBean.getPropertyValues() @@ -351,7 +352,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { return customFilters; } - static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) { + static void registerFilterChainProxyIfNecessary(ParserContext pc, Element element) { + Object source = pc.extractSource(element); BeanDefinitionRegistry registry = pc.getRegistry(); if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) { return; @@ -378,6 +380,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { requestRejected.addConstructorArgValue("requestRejectedHandler"); requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY); requestRejected.addConstructorArgValue("requestRejectedHandler"); + requestRejected.addPropertyValue("observationRegistry", getObservationRegistry(element)); AbstractBeanDefinition requestRejectedBean = requestRejected.getBeanDefinition(); String requestRejectedPostProcessorName = pc.getReaderContext().generateBeanName(requestRejectedBean); registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean); @@ -391,7 +394,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { return BeanDefinitionBuilder.rootBeanDefinition(ObservationRegistryFactory.class).getBeanDefinition(); } - static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor { + public static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor { private final String beanName; @@ -399,6 +402,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { private final String targetPropertyName; + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) { this.beanName = beanName; this.targetBeanName = targetBeanName; @@ -412,6 +417,13 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { beanDefinition.getPropertyValues().add(this.targetPropertyName, new RuntimeBeanReference(this.beanName)); } + else if (!this.observationRegistry.isNoop()) { + BeanDefinition observable = BeanDefinitionBuilder + .rootBeanDefinition(ObservationMarkingRequestRejectedHandler.class) + .addConstructorArgValue(this.observationRegistry).getBeanDefinition(); + BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName); + beanDefinition.getPropertyValues().add(this.targetPropertyName, observable); + } } @Override @@ -419,6 +431,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { } + public void setObservationRegistry(ObservationRegistry registry) { + this.observationRegistry = registry; + } + } /** diff --git a/web/src/main/java/org/springframework/security/web/firewall/ObservationMarkingRequestRejectedHandler.java b/web/src/main/java/org/springframework/security/web/firewall/ObservationMarkingRequestRejectedHandler.java new file mode 100644 index 0000000000..0f9eac70fc --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/firewall/ObservationMarkingRequestRejectedHandler.java @@ -0,0 +1,44 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.firewall; + +import java.io.IOException; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +public final class ObservationMarkingRequestRejectedHandler implements RequestRejectedHandler { + + private final ObservationRegistry registry; + + public ObservationMarkingRequestRejectedHandler(ObservationRegistry registry) { + this.registry = registry; + } + + @Override + public void handle(HttpServletRequest request, HttpServletResponse response, RequestRejectedException exception) + throws IOException, ServletException { + Observation observation = this.registry.getCurrentObservation(); + if (observation != null) { + observation.error(exception); + } + } + +}