diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java index 6191798b3d..c86f054c7e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java @@ -19,6 +19,7 @@ package org.springframework.security.config.annotation.web.builders; import java.util.ArrayList; import java.util.List; +import io.micrometer.observation.ObservationRegistry; import jakarta.servlet.Filter; import jakarta.servlet.ServletContext; import jakarta.servlet.http.HttpServletRequest; @@ -45,6 +46,7 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterInvocation; +import org.springframework.security.web.ObservationFilterChainDecorator; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.AuthorizationManagerWebInvocationPrivilegeEvaluator; import org.springframework.security.web.access.DefaultWebInvocationPrivilegeEvaluator; @@ -101,6 +103,8 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder expressionHandler = this.defaultWebSecurityExpressionHandler; @@ -303,6 +307,7 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder securityWebFilterChains; + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + @Autowired ApplicationContext context; @@ -63,10 +68,19 @@ class WebFluxSecurityConfiguration { this.securityWebFilterChains = securityWebFilterChains; } + @Autowired(required = false) + void setObservationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + } + @Bean(SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME) @Order(WEB_FILTER_CHAIN_FILTER_ORDER) WebFilterChainProxy springSecurityWebFilterChainFilter() { - return new WebFilterChainProxy(getSecurityWebFilterChains()); + WebFilterChainProxy proxy = new WebFilterChainProxy(getSecurityWebFilterChains()); + if (!this.observationRegistry.isNoop()) { + proxy.setFilterChainDecorator(new ObservationWebFilterChainDecorator(this.observationRegistry)); + } + return proxy; } @Bean(name = AbstractView.REQUEST_DATA_VALUE_PROCESSOR_BEAN_NAME) diff --git a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java index 44c6007c19..4ff9599d88 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java @@ -56,6 +56,7 @@ import org.springframework.security.config.Elements; import org.springframework.security.config.authentication.AuthenticationManagerFactoryBean; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.ObservationFilterChainDecorator; import org.springframework.security.web.PortResolverImpl; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.util.StringUtils; @@ -363,6 +364,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { fcpBldr.getRawBeanDefinition().setSource(source); fcpBldr.addConstructorArgReference(BeanIds.FILTER_CHAINS); fcpBldr.addPropertyValue("filterChainValidator", new RootBeanDefinition(DefaultFilterChainValidator.class)); + BeanDefinition filterChainDecorator = BeanDefinitionBuilder + .rootBeanDefinition(FilterChainDecoratorFactory.class) + .addPropertyValue("observationRegistry", getObservationRegistry(element)).getBeanDefinition(); + fcpBldr.addPropertyValue("filterChainDecorator", filterChainDecorator); BeanDefinition fcpBean = fcpBldr.getBeanDefinition(); pc.registerBeanComponent(new BeanComponentDefinition(fcpBean, BeanIds.FILTER_CHAIN_PROXY)); registry.registerAlias(BeanIds.FILTER_CHAIN_PROXY, BeanIds.SPRING_SECURITY_FILTER_CHAIN); @@ -509,4 +514,28 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { } + public static final class FilterChainDecoratorFactory + implements FactoryBean { + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + @Override + public FilterChainProxy.FilterChainDecorator getObject() throws Exception { + if (this.observationRegistry.isNoop()) { + return new FilterChainProxy.VirtualFilterChainDecorator(); + } + return new ObservationFilterChainDecorator(this.observationRegistry); + } + + @Override + public Class getObjectType() { + return FilterChainProxy.FilterChainDecorator.class; + } + + public void setObservationRegistry(ObservationRegistry registry) { + this.observationRegistry = registry; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityObservationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityObservationTests.java new file mode 100644 index 0000000000..7cc85c7967 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityObservationTests.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.configurers; + +import java.util.Iterator; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.test.SpringTestContext; +import org.springframework.security.config.test.SpringTestContextExtension; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetailsService; +import org.springframework.security.provisioning.InMemoryUserDetailsManager; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.test.web.servlet.MockMvc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * @author Josh Cummings + * + */ +@ExtendWith(SpringTestContextExtension.class) +public class HttpSecurityObservationTests { + + @Autowired + MockMvc mvc; + + public final SpringTestContext spring = new SpringTestContext(this); + + @Test + public void getWhenUsingObservationRegistryThenObservesRequest() throws Exception { + this.spring.register(ObservationRegistryConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/").with(httpBasic("user", "password"))) + .andExpect(status().isNotFound()); + // @formatter:on + ObservationHandler handler = this.spring.getContext().getBean(ObservationHandler.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(5)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.before"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests"); + assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.after"); + } + + @EnableWebSecurity + @Configuration + static class ObservationRegistryConfig { + + private ObservationHandler handler = mock(ObservationHandler.class); + + @Bean + SecurityFilterChain app(HttpSecurity http) throws Exception { + http.httpBasic(withDefaults()).authorizeHttpRequests((requests) -> requests.anyRequest().authenticated()); + return http.build(); + } + + @Bean + UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager( + User.withDefaultPasswordEncoder().username("user").password("password").authorities("app").build()); + } + + @Bean + ObservationHandler observationHandler() { + return this.handler; + } + + @Bean + ObservationRegistry observationRegistry() { + given(this.handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(this.handler); + return registry; + } + + } + +} diff --git a/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java b/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java index 8be9526bc4..9056fb6a77 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java @@ -16,13 +16,20 @@ package org.springframework.security.config.http; +import java.util.Iterator; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponseWrapper; import org.apache.http.HttpStatus; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -36,7 +43,10 @@ import org.springframework.test.web.servlet.MockMvc; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -101,6 +111,24 @@ public class HttpConfigTests { assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/login"); } + @Test + public void getWhenUsingObservationRegistryThenObservesRequest() throws Exception { + this.spring.configLocations(this.xml("WithObservationRegistry")).autowire(); + // @formatter:off + this.mvc.perform(get("/").with(httpBasic("user", "password"))) + .andExpect(status().isNotFound()); + // @formatter:on + ObservationHandler handler = this.spring.getContext().getBean(ObservationHandler.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(5)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.before"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations"); + assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests"); + assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.after"); + } + private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } @@ -133,4 +161,27 @@ public class HttpConfigTests { } + public static final class MockObservationRegistry implements FactoryBean { + + private ObservationHandler handler = mock(ObservationHandler.class); + + @Override + public ObservationRegistry getObject() { + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(this.handler); + given(this.handler.supportsContext(any())).willReturn(true); + return registry; + } + + @Override + public Class getObjectType() { + return ObservationRegistry.class; + } + + public void setHandler(ObservationHandler handler) { + this.handler = handler; + } + + } + } diff --git a/config/src/test/resources/org/springframework/security/config/http/HttpConfigTests-WithObservationRegistry.xml b/config/src/test/resources/org/springframework/security/config/http/HttpConfigTests-WithObservationRegistry.xml new file mode 100644 index 0000000000..d1d1839f9b --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/HttpConfigTests-WithObservationRegistry.xml @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 594812d860..19415a24d3 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -160,6 +160,8 @@ public class FilterChainProxy extends GenericFilterBean { private ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer(); + private FilterChainDecorator filterChainDecorator = new VirtualFilterChainDecorator(); + public FilterChainProxy() { } @@ -214,14 +216,21 @@ public class FilterChainProxy extends GenericFilterBean { logger.trace(LogMessage.of(() -> "No security for " + requestLine(firewallRequest))); } firewallRequest.reset(); - chain.doFilter(firewallRequest, firewallResponse); + this.filterChainDecorator.decorate(chain).doFilter(firewallRequest, firewallResponse); return; } if (logger.isDebugEnabled()) { logger.debug(LogMessage.of(() -> "Securing " + requestLine(firewallRequest))); } - VirtualFilterChain virtualFilterChain = new VirtualFilterChain(firewallRequest, chain, filters); - virtualFilterChain.doFilter(firewallRequest, firewallResponse); + FilterChain reset = (req, res) -> { + if (logger.isDebugEnabled()) { + logger.debug(LogMessage.of(() -> "Secured " + requestLine(firewallRequest))); + } + // Deactivate path stripping as we exit the security filter chain + firewallRequest.reset(); + chain.doFilter(req, res); + }; + this.filterChainDecorator.decorate(reset, filters).doFilter(firewallRequest, firewallResponse); } /** @@ -249,7 +258,7 @@ public class FilterChainProxy extends GenericFilterBean { * @return matching filter list */ public List getFilters(String url) { - return getFilters(this.firewall.getFirewalledRequest((new FilterInvocation(url, "GET").getRequest()))); + return getFilters(this.firewall.getFirewalledRequest(new FilterInvocation(url, "GET").getRequest())); } /** @@ -281,6 +290,20 @@ public class FilterChainProxy extends GenericFilterBean { this.filterChainValidator = filterChainValidator; } + /** + * Used to decorate the original {@link FilterChain} for each request + * + *

+ * By default, this decorates the filter chain with a {@link VirtualFilterChain} that + * iterates through security filters and then delegates to the original chain + * @param filterChainDecorator the strategy for constructing the filter chain + * @since 6.0 + */ + public void setFilterChainDecorator(FilterChainDecorator filterChainDecorator) { + Assert.notNull(filterChainDecorator, "filterChainDecorator cannot be null"); + this.filterChainDecorator = filterChainDecorator; + } + /** * Sets the "firewall" implementation which will be used to validate and wrap (or * potentially reject) the incoming requests. The default implementation should be @@ -326,36 +349,27 @@ public class FilterChainProxy extends GenericFilterBean { private final List additionalFilters; - private final FirewalledRequest firewalledRequest; - private final int size; private int currentPosition = 0; - private VirtualFilterChain(FirewalledRequest firewalledRequest, FilterChain chain, - List additionalFilters) { + private VirtualFilterChain(FilterChain chain, List additionalFilters) { this.originalChain = chain; this.additionalFilters = additionalFilters; this.size = additionalFilters.size(); - this.firewalledRequest = firewalledRequest; } @Override public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { if (this.currentPosition == this.size) { - if (logger.isDebugEnabled()) { - logger.debug(LogMessage.of(() -> "Secured " + requestLine(this.firewalledRequest))); - } - // Deactivate path stripping as we exit the security filter chain - this.firewalledRequest.reset(); this.originalChain.doFilter(request, response); return; } this.currentPosition++; Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1); if (logger.isTraceEnabled()) { - logger.trace(LogMessage.format("Invoking %s (%d/%d)", nextFilter.getClass().getSimpleName(), - this.currentPosition, this.size)); + String name = nextFilter.getClass().getSimpleName(); + logger.trace(LogMessage.format("Invoking %s (%d/%d)", name, this.currentPosition, this.size)); } nextFilter.doFilter(request, response, this); } @@ -376,4 +390,61 @@ public class FilterChainProxy extends GenericFilterBean { } + /** + * A strategy for decorating the provided filter chain with one that accounts for the + * {@link SecurityFilterChain} for a given request. + * + * @author Josh Cummings + * @since 6.0 + */ + public interface FilterChainDecorator { + + /** + * Provide a new {@link FilterChain} that accounts for needed security + * considerations when there are no security filters. + * @param original the original {@link FilterChain} + * @return a security-enabled {@link FilterChain} + */ + default FilterChain decorate(FilterChain original) { + return decorate(original, Collections.emptyList()); + } + + /** + * Provide a new {@link FilterChain} that accounts for the provided filters as + * well as teh original filter chain. + * @param original the original {@link FilterChain} + * @param filters the security filters + * @return a security-enabled {@link FilterChain} that includes the provided + * filters + */ + FilterChain decorate(FilterChain original, List filters); + + } + + /** + * A {@link FilterChainDecorator} that uses the {@link VirtualFilterChain} + * + * @author Josh Cummings + * @since 6.0 + */ + public static final class VirtualFilterChainDecorator implements FilterChainDecorator { + + /** + * {@inheritDoc} + */ + @Override + public FilterChain decorate(FilterChain original) { + return original; + } + + /** + * {@inheritDoc} + */ + @Override + public FilterChain decorate(FilterChain original, List filters) { + return new VirtualFilterChain(original, filters); + } + + } + } diff --git a/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java new file mode 100644 index 0000000000..8f50c52a11 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java @@ -0,0 +1,525 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import io.micrometer.common.KeyValues; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; +import io.micrometer.observation.ObservationRegistry; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; +import org.springframework.security.web.util.UrlUtils; + +/** + * A {@link org.springframework.security.web.server.FilterChainProxy.FilterChainDecorator} + * that wraps the chain in before and after observations + * + * @author Josh Cummings + * @since 6.0 + */ +public final class ObservationFilterChainDecorator implements FilterChainProxy.FilterChainDecorator { + + private static final Log logger = LogFactory.getLog(FilterChainProxy.class); + + private static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation"; + + static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests"; + + static final String SECURED_OBSERVATION_NAME = "spring.security.http.secured.requests"; + + private final ObservationRegistry registry; + + public ObservationFilterChainDecorator(ObservationRegistry registry) { + this.registry = registry; + } + + @Override + public FilterChain decorate(FilterChain original) { + return wrapUnsecured(original); + } + + @Override + public FilterChain decorate(FilterChain original, List filters) { + return new VirtualFilterChain(wrapSecured(original), wrap(filters)); + } + + private FilterChain wrapSecured(FilterChain original) { + return (req, res) -> { + AroundFilterObservation parent = observation((HttpServletRequest) req); + Observation observation = Observation.createNotStarted(SECURED_OBSERVATION_NAME, this.registry); + parent.wrap(FilterObservation.create(observation).wrap(original)).doFilter(req, res); + }; + } + + private FilterChain wrapUnsecured(FilterChain original) { + return (req, res) -> { + Observation observation = Observation.createNotStarted(UNSECURED_OBSERVATION_NAME, this.registry); + FilterObservation.create(observation).wrap(original).doFilter(req, res); + }; + } + + private List wrap(List filters) { + int size = filters.size(); + List observableFilters = new ArrayList<>(); + int position = 1; + for (Filter filter : filters) { + observableFilters.add(new ObservationFilter(this.registry, filter, position, size)); + position++; + } + return observableFilters; + } + + static AroundFilterObservation observation(HttpServletRequest request) { + return (AroundFilterObservation) request.getAttribute(ATTRIBUTE); + } + + private static final class VirtualFilterChain implements FilterChain { + + private final FilterChain originalChain; + + private final List additionalFilters; + + private final int size; + + private int currentPosition = 0; + + private VirtualFilterChain(FilterChain chain, List additionalFilters) { + this.originalChain = chain; + this.additionalFilters = additionalFilters; + this.size = additionalFilters.size(); + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { + if (this.currentPosition == this.size) { + this.originalChain.doFilter(request, response); + return; + } + this.currentPosition++; + ObservationFilter nextFilter = this.additionalFilters.get(this.currentPosition - 1); + if (logger.isTraceEnabled()) { + String name = nextFilter.getName(); + logger.trace(LogMessage.format("Invoking %s (%d/%d)", name, this.currentPosition, this.size)); + } + nextFilter.doFilter(request, response, this); + } + + } + + static final class ObservationFilter implements Filter { + + private final ObservationRegistry registry; + + private final FilterChainObservationConvention convention = new FilterChainObservationConvention(); + + private final Filter filter; + + private final String name; + + private final int position; + + private final int size; + + ObservationFilter(ObservationRegistry registry, Filter filter, int position, int size) { + this.registry = registry; + this.filter = filter; + this.name = filter.getClass().getSimpleName(); + this.position = position; + this.size = size; + } + + String getName() { + return this.name; + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + if (this.position == 1) { + AroundFilterObservation parent = parent((HttpServletRequest) request); + parent.wrap(this::wrapFilter).doFilter(request, response, chain); + } + else { + wrapFilter(request, response, chain); + } + } + + private void wrapFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + AroundFilterObservation parent = observation((HttpServletRequest) request); + FilterChainObservationContext parentBefore = (FilterChainObservationContext) parent.before().getContext(); + parentBefore.setChainSize(this.size); + parentBefore.setFilterName(this.name); + parentBefore.setChainPosition(this.position); + this.filter.doFilter(request, response, chain); + parent.start(); + FilterChainObservationContext parentAfter = (FilterChainObservationContext) parent.after().getContext(); + parentAfter.setChainSize(this.size); + parentAfter.setFilterName(this.name); + parentAfter.setChainPosition(this.size - this.position + 1); + } + + private AroundFilterObservation parent(HttpServletRequest request) { + FilterChainObservationContext beforeContext = FilterChainObservationContext.before(request); + FilterChainObservationContext afterContext = FilterChainObservationContext.after(request); + Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry); + Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry); + AroundFilterObservation parent = AroundFilterObservation.create(before, after); + request.setAttribute(ATTRIBUTE, parent); + return parent; + } + + } + + interface AroundFilterObservation extends FilterObservation { + + AroundFilterObservation NOOP = new AroundFilterObservation() { + }; + + static AroundFilterObservation create(Observation before, Observation after) { + if (before.isNoop() || after.isNoop()) { + return NOOP; + } + return new SimpleAroundFilterObservation(before, after); + } + + default Observation before() { + return Observation.NOOP; + } + + default Observation after() { + return Observation.NOOP; + } + + class SimpleAroundFilterObservation implements AroundFilterObservation { + + private final Iterator observations; + + private final Observation before; + + private final Observation after; + + private final AtomicReference currentScope = new AtomicReference<>(null); + + SimpleAroundFilterObservation(Observation before, Observation after) { + this.before = before; + this.after = after; + this.observations = Arrays.asList(before, after).iterator(); + } + + @Override + public void start() { + if (this.observations.hasNext()) { + stop(); + Observation observation = this.observations.next(); + observation.start(); + Observation.Scope scope = observation.openScope(); + this.currentScope.set(scope); + } + } + + @Override + public void error(Throwable ex) { + Observation.Scope scope = this.currentScope.get(); + if (scope == null) { + return; + } + scope.close(); + scope.getCurrentObservation().error(ex); + } + + @Override + public void stop() { + Observation.Scope scope = this.currentScope.getAndSet(null); + if (scope == null) { + return; + } + scope.close(); + scope.getCurrentObservation().stop(); + } + + @Override + public Filter wrap(Filter filter) { + return (request, response, chain) -> { + start(); + try { + filter.doFilter(request, response, chain); + } + catch (Throwable ex) { + error(ex); + throw ex; + } + finally { + stop(); + } + }; + } + + @Override + public FilterChain wrap(FilterChain chain) { + return (request, response) -> { + stop(); + try { + chain.doFilter(request, response); + } + finally { + start(); + } + }; + } + + @Override + public Observation before() { + return this.before; + } + + @Override + public Observation after() { + return this.after; + } + + } + + } + + interface FilterObservation { + + FilterObservation NOOP = new FilterObservation() { + }; + + static FilterObservation create(Observation observation) { + if (observation.isNoop()) { + return NOOP; + } + return new SimpleFilterObservation(observation); + } + + default void start() { + } + + default void error(Throwable ex) { + } + + default void stop() { + } + + default Filter wrap(Filter filter) { + return filter; + } + + default FilterChain wrap(FilterChain chain) { + return chain; + } + + class SimpleFilterObservation implements FilterObservation { + + private final Observation observation; + + SimpleFilterObservation(Observation observation) { + this.observation = observation; + } + + @Override + public void start() { + this.observation.start(); + } + + @Override + public void error(Throwable ex) { + this.observation.error(ex); + } + + @Override + public void stop() { + this.observation.stop(); + } + + @Override + public Filter wrap(Filter filter) { + if (this.observation.isNoop()) { + return filter; + } + return (request, response, chain) -> { + this.observation.start(); + try (Observation.Scope scope = this.observation.openScope()) { + filter.doFilter(request, response, chain); + } + catch (Throwable ex) { + this.observation.error(ex); + throw ex; + } + finally { + this.observation.stop(); + } + }; + } + + @Override + public FilterChain wrap(FilterChain chain) { + if (this.observation.isNoop()) { + return chain; + } + return (request, response) -> { + this.observation.start(); + try (Observation.Scope scope = this.observation.openScope()) { + chain.doFilter(request, response); + } + catch (Throwable ex) { + this.observation.error(ex); + throw ex; + } + finally { + this.observation.stop(); + } + }; + } + + } + + } + + static final class FilterChainObservationContext extends Observation.Context { + + private final ServletRequest request; + + private final String filterSection; + + private String filterName; + + private int chainPosition; + + private int chainSize; + + private FilterChainObservationContext(ServletRequest request, String filterSection) { + this.filterSection = filterSection; + this.request = request; + } + + static FilterChainObservationContext before(ServletRequest request) { + return new FilterChainObservationContext(request, "before"); + } + + static FilterChainObservationContext after(ServletRequest request) { + return new FilterChainObservationContext(request, "after"); + } + + @Override + public void setName(String name) { + super.setName(name); + if (name != null) { + setContextualName(name + "." + this.filterSection); + } + } + + String getRequestLine() { + return requestLine((HttpServletRequest) this.request); + } + + String getFilterSection() { + return this.filterSection; + } + + String getFilterName() { + return this.filterName; + } + + void setFilterName(String filterName) { + this.filterName = filterName; + } + + int getChainPosition() { + return this.chainPosition; + } + + void setChainPosition(int chainPosition) { + this.chainPosition = chainPosition; + } + + int getChainSize() { + return this.chainSize; + } + + void setChainSize(int chainSize) { + this.chainSize = chainSize; + } + + private static String requestLine(HttpServletRequest request) { + return request.getMethod() + " " + UrlUtils.buildRequestUrl(request); + } + + } + + static final class FilterChainObservationConvention + implements ObservationConvention { + + static final String CHAIN_OBSERVATION_NAME = "spring.security.http.chains"; + + private static final String REQUEST_LINE_NAME = "request.line"; + + private static final String CHAIN_POSITION_NAME = "chain.position"; + + private static final String CHAIN_SIZE_NAME = "chain.size"; + + private static final String FILTER_SECTION_NAME = "filter.section"; + + private static final String FILTER_NAME = "current.filter.name"; + + @Override + public String getName() { + return CHAIN_OBSERVATION_NAME; + } + + @Override + public KeyValues getLowCardinalityKeyValues(FilterChainObservationContext context) { + KeyValues kv = KeyValues.of(CHAIN_SIZE_NAME, String.valueOf(context.getChainSize())) + .and(CHAIN_POSITION_NAME, String.valueOf(context.getChainPosition())) + .and(FILTER_SECTION_NAME, context.getFilterSection()); + if (context.getFilterName() != null) { + kv = kv.and(FILTER_NAME, context.getFilterName()); + } + return kv; + } + + @Override + public KeyValues getHighCardinalityKeyValues(FilterChainObservationContext context) { + String requestLine = context.getRequestLine(); + return KeyValues.of(REQUEST_LINE_NAME, requestLine); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof FilterChainObservationContext; + } + + } + +} diff --git a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java new file mode 100644 index 0000000000..c815dae5a0 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java @@ -0,0 +1,531 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.concurrent.atomic.AtomicReference; + +import io.micrometer.common.KeyValues; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; +import io.micrometer.observation.ObservationRegistry; +import reactor.core.publisher.Mono; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebHandler; + +/** + * A + * {@link org.springframework.security.web.server.WebFilterChainProxy.WebFilterChainDecorator} + * that wraps the chain in before and after observations + * + * @author Josh Cummings + * @since 6.0 + */ +public final class ObservationWebFilterChainDecorator implements WebFilterChainProxy.WebFilterChainDecorator { + + private static final String ATTRIBUTE = ObservationWebFilterChainDecorator.class + ".observation"; + + static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests"; + + static final String SECURED_OBSERVATION_NAME = "spring.security.http.secured.requests"; + + private final ObservationRegistry registry; + + public ObservationWebFilterChainDecorator(ObservationRegistry registry) { + this.registry = registry; + } + + @Override + public WebFilterChain decorate(WebFilterChain original) { + return wrapUnsecured(original); + } + + @Override + public WebFilterChain decorate(WebFilterChain original, List filters) { + return new ObservationWebFilterChain(wrapSecured(original)::filter, wrap(filters)); + } + + private static AroundWebFilterObservation observation(ServerWebExchange exchange) { + return exchange.getAttribute(ATTRIBUTE); + } + + private WebFilterChain wrapSecured(WebFilterChain original) { + return (exchange) -> { + AroundWebFilterObservation parent = observation(exchange); + Observation observation = Observation.createNotStarted(SECURED_OBSERVATION_NAME, this.registry); + return parent.wrap(WebFilterObservation.create(observation).wrap(original)).filter(exchange); + }; + } + + private WebFilterChain wrapUnsecured(WebFilterChain original) { + return (exchange) -> { + Observation observation = Observation.createNotStarted(UNSECURED_OBSERVATION_NAME, this.registry); + return WebFilterObservation.create(observation).wrap(original).filter(exchange); + }; + } + + private List wrap(List filters) { + int size = filters.size(); + List observableFilters = new ArrayList<>(); + int position = 1; + for (WebFilter filter : filters) { + observableFilters.add(new ObservationWebFilter(this.registry, filter, position, size)); + position++; + } + return observableFilters; + } + + static class ObservationWebFilterChain implements WebFilterChain { + + private final WebHandler handler; + + @Nullable + private final ObservationWebFilter currentFilter; + + @Nullable + private final ObservationWebFilterChain chain; + + /** + * Public constructor with the list of filters and the target handler to use. + * @param handler the target handler + * @param filters the filters ahead of the handler + * @since 5.1 + */ + ObservationWebFilterChain(WebHandler handler, List filters) { + Assert.notNull(handler, "WebHandler is required"); + this.handler = handler; + ObservationWebFilterChain chain = initChain(filters, handler); + this.currentFilter = chain.currentFilter; + this.chain = chain.chain; + } + + private static ObservationWebFilterChain initChain(List filters, WebHandler handler) { + ObservationWebFilterChain chain = new ObservationWebFilterChain(handler, null, null); + ListIterator iterator = filters.listIterator(filters.size()); + while (iterator.hasPrevious()) { + chain = new ObservationWebFilterChain(handler, iterator.previous(), chain); + } + return chain; + } + + /** + * Private constructor to represent one link in the chain. + */ + private ObservationWebFilterChain(WebHandler handler, @Nullable ObservationWebFilter currentFilter, + @Nullable ObservationWebFilterChain chain) { + this.currentFilter = currentFilter; + this.handler = handler; + this.chain = chain; + } + + @Override + public Mono filter(ServerWebExchange exchange) { + return Mono.defer(() -> (this.currentFilter != null && this.chain != null) + ? invokeFilter(this.currentFilter, this.chain, exchange) : this.handler.handle(exchange)); + } + + private Mono invokeFilter(ObservationWebFilter current, ObservationWebFilterChain chain, + ServerWebExchange exchange) { + String currentName = current.getName(); + return current.filter(exchange, chain).checkpoint(currentName + " [DefaultWebFilterChain]"); + } + + } + + static final class ObservationWebFilter implements WebFilter { + + private final ObservationRegistry registry; + + private final WebFilterChainObservationConvention convention = new WebFilterChainObservationConvention(); + + private final WebFilter filter; + + private final String name; + + private final int position; + + private final int size; + + ObservationWebFilter(ObservationRegistry registry, WebFilter filter, int position, int size) { + this.registry = registry; + this.filter = filter; + this.name = filter.getClass().getSimpleName(); + this.position = position; + this.size = size; + } + + String getName() { + return this.name; + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + if (this.position == 1) { + AroundWebFilterObservation parent = parent(exchange); + return parent.wrap(this::wrapFilter).filter(exchange, chain); + } + else { + return wrapFilter(exchange, chain); + } + } + + private Mono wrapFilter(ServerWebExchange exchange, WebFilterChain chain) { + AroundWebFilterObservation parent = observation(exchange); + WebFilterChainObservationContext parentBefore = (WebFilterChainObservationContext) parent.before() + .getContext(); + parentBefore.setChainSize(this.size); + parentBefore.setFilterName(this.name); + parentBefore.setChainPosition(this.position); + return this.filter.filter(exchange, chain).doOnSuccess((result) -> { + parent.start(); + WebFilterChainObservationContext parentAfter = (WebFilterChainObservationContext) parent.after() + .getContext(); + parentAfter.setChainSize(this.size); + parentAfter.setFilterName(this.name); + parentAfter.setChainPosition(this.size - this.position + 1); + }); + } + + private AroundWebFilterObservation parent(ServerWebExchange exchange) { + WebFilterChainObservationContext beforeContext = WebFilterChainObservationContext.before(exchange); + WebFilterChainObservationContext afterContext = WebFilterChainObservationContext.after(exchange); + Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry); + Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry); + AroundWebFilterObservation parent = AroundWebFilterObservation.create(before, after); + exchange.getAttributes().put(ATTRIBUTE, parent); + return parent; + } + + } + + interface AroundWebFilterObservation extends WebFilterObservation { + + AroundWebFilterObservation NOOP = new AroundWebFilterObservation() { + }; + + static AroundWebFilterObservation create(Observation before, Observation after) { + if (before.isNoop() || after.isNoop()) { + return NOOP; + } + return new SimpleAroundWebFilterObservation(before, after); + } + + default Observation before() { + return Observation.NOOP; + } + + default Observation after() { + return Observation.NOOP; + } + + class SimpleAroundWebFilterObservation implements AroundWebFilterObservation { + + private final Iterator observations; + + private final Observation before; + + private final Observation after; + + private final AtomicReference currentObservation = new AtomicReference<>(null); + + SimpleAroundWebFilterObservation(Observation before, Observation after) { + this.before = before; + this.after = after; + this.observations = Arrays.asList(before, after).iterator(); + } + + @Override + public void start() { + if (this.observations.hasNext()) { + stop(); + Observation observation = this.observations.next(); + observation.start(); + this.currentObservation.set(observation); + } + } + + @Override + public void error(Throwable ex) { + Observation observation = this.currentObservation.get(); + if (observation == null) { + return; + } + observation.error(ex); + } + + @Override + public void stop() { + Observation observation = this.currentObservation.getAndSet(null); + if (observation == null) { + return; + } + observation.stop(); + } + + @Override + public WebFilterChain wrap(WebFilterChain chain) { + return (exchange) -> { + stop(); + // @formatter:off + return chain.filter(exchange) + .doOnSuccess((v) -> start()) + .doOnCancel(this::start) + .doOnError((t) -> { + error(t); + start(); + }); + // @formatter:on + }; + } + + @Override + public WebFilter wrap(WebFilter filter) { + return (exchange, chain) -> { + start(); + // @formatter:off + return filter.filter(exchange, chain) + .doOnSuccess((v) -> stop()) + .doOnCancel(this::stop) + .doOnError((t) -> { + error(t); + stop(); + }); + // @formatter:on + }; + } + + @Override + public Observation before() { + return this.before; + } + + @Override + public Observation after() { + return this.after; + } + + } + + } + + interface WebFilterObservation { + + WebFilterObservation NOOP = new WebFilterObservation() { + }; + + static WebFilterObservation create(Observation observation) { + if (observation.isNoop()) { + return NOOP; + } + return new SimpleWebFilterObservation(observation); + } + + default void start() { + } + + default void error(Throwable ex) { + } + + default void stop() { + } + + default WebFilter wrap(WebFilter filter) { + return filter; + } + + default WebFilterChain wrap(WebFilterChain chain) { + return chain; + } + + class SimpleWebFilterObservation implements WebFilterObservation { + + private final Observation observation; + + SimpleWebFilterObservation(Observation observation) { + this.observation = observation; + } + + @Override + public void start() { + this.observation.start(); + } + + @Override + public void error(Throwable ex) { + this.observation.error(ex); + } + + @Override + public void stop() { + this.observation.stop(); + } + + @Override + public WebFilter wrap(WebFilter filter) { + if (this.observation.isNoop()) { + return filter; + } + return (exchange, chain) -> { + this.observation.start(); + return filter.filter(exchange, chain).doOnSuccess((v) -> this.observation.stop()) + .doOnCancel(this.observation::stop).doOnError((t) -> { + this.observation.error(t); + this.observation.stop(); + }); + }; + } + + @Override + public WebFilterChain wrap(WebFilterChain chain) { + if (this.observation.isNoop()) { + return chain; + } + return (exchange) -> { + this.observation.start(); + return chain.filter(exchange).doOnSuccess((v) -> this.observation.stop()) + .doOnCancel(this.observation::stop).doOnError((t) -> { + this.observation.error(t); + this.observation.stop(); + }); + }; + } + + } + + } + + static final class WebFilterChainObservationContext extends Observation.Context { + + private final ServerWebExchange exchange; + + private final String filterSection; + + private String filterName; + + private int chainPosition; + + private int chainSize; + + private WebFilterChainObservationContext(ServerWebExchange exchange, String filterSection) { + this.exchange = exchange; + this.filterSection = filterSection; + } + + static WebFilterChainObservationContext before(ServerWebExchange exchange) { + return new WebFilterChainObservationContext(exchange, "before"); + } + + static WebFilterChainObservationContext after(ServerWebExchange exchange) { + return new WebFilterChainObservationContext(exchange, "after"); + } + + @Override + public void setName(String name) { + super.setName(name); + if (name != null) { + setContextualName(name + "." + this.filterSection); + } + } + + String getRequestLine() { + return this.exchange.getRequest().getPath().toString(); + } + + String getFilterSection() { + return this.filterSection; + } + + String getFilterName() { + return this.filterName; + } + + void setFilterName(String filterName) { + this.filterName = filterName; + } + + int getChainPosition() { + return this.chainPosition; + } + + void setChainPosition(int chainPosition) { + this.chainPosition = chainPosition; + } + + int getChainSize() { + return this.chainSize; + } + + void setChainSize(int chainSize) { + this.chainSize = chainSize; + } + + } + + static final class WebFilterChainObservationConvention + implements ObservationConvention { + + static final String CHAIN_OBSERVATION_NAME = "spring.security.http.chains"; + + private static final String REQUEST_LINE_NAME = "request.line"; + + private static final String CHAIN_POSITION_NAME = "chain.position"; + + private static final String CHAIN_SIZE_NAME = "chain.size"; + + private static final String FILTER_SECTION_NAME = "filter.section"; + + private static final String FILTER_NAME = "current.filter.name"; + + @Override + public String getName() { + return CHAIN_OBSERVATION_NAME; + } + + @Override + public KeyValues getLowCardinalityKeyValues(WebFilterChainObservationContext context) { + KeyValues kv = KeyValues.of(CHAIN_SIZE_NAME, String.valueOf(context.getChainSize())) + .and(CHAIN_POSITION_NAME, String.valueOf(context.getChainPosition())) + .and(FILTER_SECTION_NAME, context.getFilterSection()); + if (context.getFilterName() != null) { + kv = kv.and(FILTER_NAME, context.getFilterName()); + } + return kv; + } + + @Override + public KeyValues getHighCardinalityKeyValues(WebFilterChainObservationContext context) { + String requestLine = context.getRequestLine(); + return KeyValues.of(REQUEST_LINE_NAME, requestLine); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof WebFilterChainObservationContext; + } + + } + +} diff --git a/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java b/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java index 33704a096f..31a1156c5f 100644 --- a/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/server/WebFilterChainProxy.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,11 +17,15 @@ package org.springframework.security.web.server; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import jakarta.servlet.FilterChain; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; @@ -37,6 +41,8 @@ public class WebFilterChainProxy implements WebFilter { private final List filters; + private WebFilterChainDecorator filterChainDecorator = new DefaultWebFilterChainDecorator(); + public WebFilterChainProxy(List filters) { this.filters = filters; } @@ -49,10 +55,82 @@ public class WebFilterChainProxy implements WebFilter { public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return Flux.fromIterable(this.filters) .filterWhen((securityWebFilterChain) -> securityWebFilterChain.matches(exchange)).next() - .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) + .switchIfEmpty( + Mono.defer(() -> this.filterChainDecorator.decorate(chain).filter(exchange).then(Mono.empty()))) .flatMap((securityWebFilterChain) -> securityWebFilterChain.getWebFilters().collectList()) - .map((filters) -> new DefaultWebFilterChain(chain::filter, filters)) + .map((filters) -> this.filterChainDecorator.decorate(chain, filters)) .flatMap((securedChain) -> securedChain.filter(exchange)); } + /** + * Used to decorate the original {@link FilterChain} for each request + * + *

+ * By default, this decorates the filter chain with a {@link DefaultWebFilterChain} + * that iterates through security filters and then delegates to the original chain + * @param filterChainDecorator the strategy for constructing the filter chain + * @since 6.0 + */ + public void setFilterChainDecorator(WebFilterChainDecorator filterChainDecorator) { + Assert.notNull(filterChainDecorator, "filterChainDecorator cannot be null"); + this.filterChainDecorator = filterChainDecorator; + } + + /** + * A strategy for decorating the provided filter chain with one that accounts for the + * {@link SecurityFilterChain} for a given request. + * + * @author Josh Cummings + * @since 6.0 + */ + public interface WebFilterChainDecorator { + + /** + * Provide a new {@link FilterChain} that accounts for needed security + * considerations when there are no security filters. + * @param original the original {@link FilterChain} + * @return a security-enabled {@link FilterChain} + */ + default WebFilterChain decorate(WebFilterChain original) { + return decorate(original, Collections.emptyList()); + } + + /** + * Provide a new {@link FilterChain} that accounts for the provided filters as + * well as teh original filter chain. + * @param original the original {@link FilterChain} + * @param filters the security filters + * @return a security-enabled {@link FilterChain} that includes the provided + * filters + */ + WebFilterChain decorate(WebFilterChain original, List filters); + + } + + /** + * A {@link WebFilterChainDecorator} that uses the {@link DefaultWebFilterChain} + * + * @author Josh Cummings + * @since 6.0 + */ + public static class DefaultWebFilterChainDecorator implements WebFilterChainDecorator { + + /** + * {@inheritDoc} + */ + @Override + public WebFilterChain decorate(WebFilterChain original) { + return original; + } + + /** + * {@inheritDoc} + */ + @Override + public WebFilterChain decorate(WebFilterChain original, List filters) { + return new DefaultWebFilterChain(original::filter, filters); + } + + } + } diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 16f6581152..34fe85c6b6 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -16,19 +16,27 @@ package org.springframework.security.web; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.springframework.mock.web.MockHttpServletRequest; @@ -50,7 +58,9 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -270,10 +280,198 @@ public class FilterChainProxyTests { this.fcp.setRequestRejectedHandler(rjh); RequestRejectedException requestRejectedException = new RequestRejectedException("Contains illegal chars"); ServletException servletException = new ServletException(requestRejectedException); - given(fw.getFirewalledRequest(this.request)).willReturn(mock(FirewalledRequest.class)); + given(fw.getFirewalledRequest(this.request)).willReturn(new MockFirewalledRequest(this.request)); willThrow(servletException).given(this.chain).doFilter(any(), any()); this.fcp.doFilter(this.request, this.response, this.chain); verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); } + @Test + public void doFilterWhenMatchesThenObservationRegistryObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + given(this.matcher.matches(any())).willReturn(true); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter)); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter filter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + filter.doFilter(this.request, this.response, this.chain); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(4)).onStart(captor.capture()); + verify(handler, times(4)).onStop(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 1); + assertThat(contexts.next().getName()).isEqualTo(ObservationFilterChainDecorator.SECURED_OBSERVATION_NAME); + assertFilterChainObservation(contexts.next(), "after", 1); + } + + @Test + public void doFilterWhenMultipleFiltersThenObservationRegistryObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + given(this.matcher.matches(any())).willReturn(true); + Filter one = mockFilter(); + Filter two = mockFilter(); + Filter three = mockFilter(); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, one, two, three); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter filter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + filter.doFilter(this.request, this.response, this.chain); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(4)).onStart(captor.capture()); + verify(handler, times(4)).onStop(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 3); + assertThat(contexts.next().getName()).isEqualTo(ObservationFilterChainDecorator.SECURED_OBSERVATION_NAME); + assertFilterChainObservation(contexts.next(), "after", 3); + } + + @Test + public void doFilterWhenMismatchesThenObservationRegistryObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter)); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter filter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + filter.doFilter(this.request, this.response, this.chain); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(2)).onStart(captor.capture()); + verify(handler, times(2)).onStop(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertThat(contexts.next().getName()).isEqualTo(ObservationFilterChainDecorator.UNSECURED_OBSERVATION_NAME); + } + + @Test + public void doFilterWhenFilterExceptionThenObservationRegistryObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + willThrow(IllegalStateException.class).given(this.filter).doFilter(any(), any(), any()); + given(this.matcher.matches(any())).willReturn(true); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter)); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter filter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(() -> filter.doFilter(this.request, this.response, this.chain)); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(2)).onStart(captor.capture()); + verify(handler, times(2)).onStop(any()); + verify(handler, atLeastOnce()).onError(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 1); + } + + @Test + public void doFilterWhenExceptionWithMultipleFiltersThenObservationRegistryObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + given(this.matcher.matches(any())).willReturn(true); + Filter one = mockFilter(); + Filter two = mock(Filter.class); + willThrow(IllegalStateException.class).given(two).doFilter(any(), any(), any()); + Filter three = mockFilter(); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, one, two, three); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter filter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(() -> filter.doFilter(this.request, this.response, this.chain)); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(2)).onStart(captor.capture()); + verify(handler, times(2)).onStop(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 2); + } + + @Test + public void doFilterWhenOneFilterDoesNotProceedThenObservationRegistryObserves() throws Exception { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + given(this.matcher.matches(any())).willReturn(true); + Filter one = mockFilter(); + Filter two = mock(Filter.class); + Filter three = mockFilter(); + SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, one, two, three); + FilterChainProxy fcp = new FilterChainProxy(sec); + fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry)); + Filter filter = ObservationFilterChainDecorator.FilterObservation + .create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + filter.doFilter(this.request, this.response, this.chain); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(3)).onStart(captor.capture()); + verify(handler, times(3)).onStop(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 2); + assertFilterChainObservation(contexts.next(), "after", 3); + } + + static void assertFilterChainObservation(Observation.Context context, String filterSection, int chainPosition) { + assertThat(context).isInstanceOf(ObservationFilterChainDecorator.FilterChainObservationContext.class); + ObservationFilterChainDecorator.FilterChainObservationContext filterChainObservationContext = (ObservationFilterChainDecorator.FilterChainObservationContext) context; + assertThat(context.getName()) + .isEqualTo(ObservationFilterChainDecorator.FilterChainObservationConvention.CHAIN_OBSERVATION_NAME); + assertThat(context.getContextualName()).endsWith(filterSection); + assertThat(filterChainObservationContext.getChainPosition()).isEqualTo(chainPosition); + } + + static Filter mockFilter() throws Exception { + Filter filter = mock(Filter.class); + willAnswer((invocation) -> { + HttpServletRequest request = invocation.getArgument(0, HttpServletRequest.class); + HttpServletResponse response = invocation.getArgument(1, HttpServletResponse.class); + FilterChain chain = invocation.getArgument(2, FilterChain.class); + chain.doFilter(request, response); + return null; + }).given(filter).doFilter(any(), any(), any()); + return filter; + } + + private static class MockFirewalledRequest extends FirewalledRequest { + + MockFirewalledRequest(HttpServletRequest request) { + super(request); + } + + @Override + public void reset() { + + } + + } + + private static class MockFilter implements Filter { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + chain.doFilter(request, response); + } + + } + } diff --git a/web/src/test/java/org/springframework/security/web/server/WebFilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/server/WebFilterChainProxyTests.java index 4f527b6850..6c795cb640 100644 --- a/web/src/test/java/org/springframework/security/web/server/WebFilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/server/WebFilterChainProxyTests.java @@ -17,12 +17,22 @@ package org.springframework.security.web.server; import java.util.Arrays; +import java.util.Iterator; import java.util.List; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import reactor.core.publisher.Mono; import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.web.server.ObservationWebFilterChainDecorator.WebFilterChainObservationContext; +import org.springframework.security.web.server.ObservationWebFilterChainDecorator.WebFilterChainObservationConvention; +import org.springframework.security.web.server.ObservationWebFilterChainDecorator.WebFilterObservation; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.test.web.reactive.server.WebTestClient; @@ -30,6 +40,15 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + /** * @author Rob Winch * @since 5.0 @@ -47,6 +66,86 @@ public class WebFilterChainProxyTests { .isNotFound(); } + @Test + public void doFilterWhenMatchesThenObservationRegistryObserves() { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + List filters = Arrays.asList(new PassthroughWebFilter()); + ServerWebExchangeMatcher match = (exchange) -> MatchResult.match(); + MatcherSecurityWebFilterChain chain = new MatcherSecurityWebFilterChain(match, filters); + WebFilterChainProxy fcp = new WebFilterChainProxy(chain); + fcp.setFilterChainDecorator(new ObservationWebFilterChainDecorator(registry)); + WebFilter filter = WebFilterObservation.create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + WebFilterChain mockChain = mock(WebFilterChain.class); + given(mockChain.filter(any())).willReturn(Mono.empty()); + filter.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/")), mockChain).block(); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(4)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 1); + assertThat(contexts.next().getName()).isEqualTo(ObservationWebFilterChainDecorator.SECURED_OBSERVATION_NAME); + assertFilterChainObservation(contexts.next(), "after", 1); + } + + @Test + public void doFilterWhenMismatchesThenObservationRegistryObserves() { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + List filters = Arrays.asList(new PassthroughWebFilter()); + ServerWebExchangeMatcher notMatch = (exchange) -> MatchResult.notMatch(); + MatcherSecurityWebFilterChain chain = new MatcherSecurityWebFilterChain(notMatch, filters); + WebFilterChainProxy fcp = new WebFilterChainProxy(chain); + fcp.setFilterChainDecorator(new ObservationWebFilterChainDecorator(registry)); + WebFilter filter = WebFilterObservation.create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + WebFilterChain mockChain = mock(WebFilterChain.class); + given(mockChain.filter(any())).willReturn(Mono.empty()); + filter.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/")), mockChain).block(); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(2)).onStart(captor.capture()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertThat(contexts.next().getName()).isEqualTo(ObservationWebFilterChainDecorator.UNSECURED_OBSERVATION_NAME); + } + + @Test + public void doFilterWhenFilterExceptionThenObservationRegistryObserves() { + ObservationHandler handler = mock(ObservationHandler.class); + given(handler.supportsContext(any())).willReturn(true); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + WebFilter error = mock(WebFilter.class); + given(error.filter(any(), any())).willReturn(Mono.error(new IllegalStateException())); + List filters = Arrays.asList(error); + ServerWebExchangeMatcher match = (exchange) -> MatchResult.match(); + MatcherSecurityWebFilterChain chain = new MatcherSecurityWebFilterChain(match, filters); + WebFilterChainProxy fcp = new WebFilterChainProxy(chain); + fcp.setFilterChainDecorator(new ObservationWebFilterChainDecorator(registry)); + WebFilter filter = WebFilterObservation.create(Observation.createNotStarted("wrap", registry)).wrap(fcp); + WebFilterChain mockChain = mock(WebFilterChain.class); + given(mockChain.filter(any())).willReturn(Mono.empty()); + assertThatExceptionOfType(IllegalStateException.class).isThrownBy( + () -> filter.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/")), mockChain).block()); + ArgumentCaptor captor = ArgumentCaptor.forClass(Observation.Context.class); + verify(handler, times(2)).onStart(captor.capture()); + verify(handler, atLeastOnce()).onError(any()); + Iterator contexts = captor.getAllValues().iterator(); + assertThat(contexts.next().getName()).isEqualTo("wrap"); + assertFilterChainObservation(contexts.next(), "before", 1); + } + + static void assertFilterChainObservation(Observation.Context context, String filterSection, int chainPosition) { + assertThat(context).isInstanceOf(WebFilterChainObservationContext.class); + WebFilterChainObservationContext filterChainObservationContext = (WebFilterChainObservationContext) context; + assertThat(context.getName()).isEqualTo(WebFilterChainObservationConvention.CHAIN_OBSERVATION_NAME); + assertThat(context.getContextualName()).endsWith(filterSection); + assertThat(filterChainObservationContext.getChainPosition()).isEqualTo(chainPosition); + } + static class Http200WebFilter implements WebFilter { @Override @@ -56,4 +155,13 @@ public class WebFilterChainProxyTests { } + static class PassthroughWebFilter implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return chain.filter(exchange); + } + + } + }