Instrument Filter Chain

Closes gh-11911
This commit is contained in:
Josh Cummings 2022-09-21 18:04:31 -06:00
parent 8c610684f3
commit 99a87179dd
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
12 changed files with 1798 additions and 21 deletions

View File

@ -19,6 +19,7 @@ package org.springframework.security.config.annotation.web.builders;
import java.util.ArrayList;
import java.util.List;
import io.micrometer.observation.ObservationRegistry;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletContext;
import jakarta.servlet.http.HttpServletRequest;
@ -45,6 +46,7 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.FilterInvocation;
import org.springframework.security.web.ObservationFilterChainDecorator;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.access.AuthorizationManagerWebInvocationPrivilegeEvaluator;
import org.springframework.security.web.access.DefaultWebInvocationPrivilegeEvaluator;
@ -101,6 +103,8 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
private WebInvocationPrivilegeEvaluator privilegeEvaluator;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
private DefaultWebSecurityExpressionHandler defaultWebSecurityExpressionHandler = new DefaultWebSecurityExpressionHandler();
private SecurityExpressionHandler<FilterInvocation> expressionHandler = this.defaultWebSecurityExpressionHandler;
@ -303,6 +307,7 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
if (this.requestRejectedHandler != null) {
filterChainProxy.setRequestRejectedHandler(this.requestRejectedHandler);
}
filterChainProxy.setFilterChainDecorator(getFilterChainDecorator());
filterChainProxy.afterPropertiesSet();
Filter result = filterChainProxy;
@ -366,6 +371,11 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
}
catch (NoSuchBeanDefinitionException ex) {
}
try {
this.observationRegistry = applicationContext.getBean(ObservationRegistry.class);
}
catch (NoSuchBeanDefinitionException ex) {
}
}
@Override
@ -373,6 +383,13 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
this.servletContext = servletContext;
}
FilterChainProxy.FilterChainDecorator getFilterChainDecorator() {
if (this.observationRegistry.isNoop()) {
return new FilterChainProxy.VirtualFilterChainDecorator();
}
return new ObservationFilterChainDecorator(this.observationRegistry);
}
/**
* Allows registering {@link RequestMatcher} instances that should be ignored by
* Spring Security.

View File

@ -19,6 +19,8 @@ package org.springframework.security.config.annotation.web.reactive;
import java.util.Arrays;
import java.util.List;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.context.ApplicationContext;
@ -28,6 +30,7 @@ import org.springframework.core.annotation.Order;
import org.springframework.security.config.crypto.RsaKeyConversionServicePostProcessor;
import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor;
import org.springframework.security.web.server.ObservationWebFilterChainDecorator;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.util.ClassUtils;
@ -55,6 +58,8 @@ class WebFluxSecurityConfiguration {
private List<SecurityWebFilterChain> securityWebFilterChains;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
@Autowired
ApplicationContext context;
@ -63,10 +68,19 @@ class WebFluxSecurityConfiguration {
this.securityWebFilterChains = securityWebFilterChains;
}
@Autowired(required = false)
void setObservationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
}
@Bean(SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME)
@Order(WEB_FILTER_CHAIN_FILTER_ORDER)
WebFilterChainProxy springSecurityWebFilterChainFilter() {
return new WebFilterChainProxy(getSecurityWebFilterChains());
WebFilterChainProxy proxy = new WebFilterChainProxy(getSecurityWebFilterChains());
if (!this.observationRegistry.isNoop()) {
proxy.setFilterChainDecorator(new ObservationWebFilterChainDecorator(this.observationRegistry));
}
return proxy;
}
@Bean(name = AbstractView.REQUEST_DATA_VALUE_PROCESSOR_BEAN_NAME)

View File

@ -56,6 +56,7 @@ import org.springframework.security.config.Elements;
import org.springframework.security.config.authentication.AuthenticationManagerFactoryBean;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.ObservationFilterChainDecorator;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.util.StringUtils;
@ -363,6 +364,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
fcpBldr.getRawBeanDefinition().setSource(source);
fcpBldr.addConstructorArgReference(BeanIds.FILTER_CHAINS);
fcpBldr.addPropertyValue("filterChainValidator", new RootBeanDefinition(DefaultFilterChainValidator.class));
BeanDefinition filterChainDecorator = BeanDefinitionBuilder
.rootBeanDefinition(FilterChainDecoratorFactory.class)
.addPropertyValue("observationRegistry", getObservationRegistry(element)).getBeanDefinition();
fcpBldr.addPropertyValue("filterChainDecorator", filterChainDecorator);
BeanDefinition fcpBean = fcpBldr.getBeanDefinition();
pc.registerBeanComponent(new BeanComponentDefinition(fcpBean, BeanIds.FILTER_CHAIN_PROXY));
registry.registerAlias(BeanIds.FILTER_CHAIN_PROXY, BeanIds.SPRING_SECURITY_FILTER_CHAIN);
@ -509,4 +514,28 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
}
public static final class FilterChainDecoratorFactory
implements FactoryBean<FilterChainProxy.FilterChainDecorator> {
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
@Override
public FilterChainProxy.FilterChainDecorator getObject() throws Exception {
if (this.observationRegistry.isNoop()) {
return new FilterChainProxy.VirtualFilterChainDecorator();
}
return new ObservationFilterChainDecorator(this.observationRegistry);
}
@Override
public Class<?> getObjectType() {
return FilterChainProxy.FilterChainDecorator.class;
}
public void setObservationRegistry(ObservationRegistry registry) {
this.observationRegistry = registry;
}
}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.configurers;
import java.util.Iterator;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.test.web.servlet.MockMvc;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
/**
* @author Josh Cummings
*
*/
@ExtendWith(SpringTestContextExtension.class)
public class HttpSecurityObservationTests {
@Autowired
MockMvc mvc;
public final SpringTestContext spring = new SpringTestContext(this);
@Test
public void getWhenUsingObservationRegistryThenObservesRequest() throws Exception {
this.spring.register(ObservationRegistryConfig.class).autowire();
// @formatter:off
this.mvc.perform(get("/").with(httpBasic("user", "password")))
.andExpect(status().isNotFound());
// @formatter:on
ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(5)).onStart(captor.capture());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.before");
assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications");
assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations");
assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests");
assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.after");
}
@EnableWebSecurity
@Configuration
static class ObservationRegistryConfig {
private ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
@Bean
SecurityFilterChain app(HttpSecurity http) throws Exception {
http.httpBasic(withDefaults()).authorizeHttpRequests((requests) -> requests.anyRequest().authenticated());
return http.build();
}
@Bean
UserDetailsService userDetailsService() {
return new InMemoryUserDetailsManager(
User.withDefaultPasswordEncoder().username("user").password("password").authorities("app").build());
}
@Bean
ObservationHandler<Observation.Context> observationHandler() {
return this.handler;
}
@Bean
ObservationRegistry observationRegistry() {
given(this.handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(this.handler);
return registry;
}
}
}

View File

@ -16,13 +16,20 @@
package org.springframework.security.config.http;
import java.util.Iterator;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpServletResponseWrapper;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
@ -36,7 +43,10 @@ import org.springframework.test.web.servlet.MockMvc;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@ -101,6 +111,24 @@ public class HttpConfigTests {
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/login");
}
@Test
public void getWhenUsingObservationRegistryThenObservesRequest() throws Exception {
this.spring.configLocations(this.xml("WithObservationRegistry")).autowire();
// @formatter:off
this.mvc.perform(get("/").with(httpBasic("user", "password")))
.andExpect(status().isNotFound());
// @formatter:on
ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(5)).onStart(captor.capture());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.before");
assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications");
assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations");
assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests");
assertThat(contexts.next().getContextualName()).isEqualTo("spring.security.http.chains.after");
}
private String xml(String configName) {
return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml";
}
@ -133,4 +161,27 @@ public class HttpConfigTests {
}
public static final class MockObservationRegistry implements FactoryBean<ObservationRegistry> {
private ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
@Override
public ObservationRegistry getObject() {
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(this.handler);
given(this.handler.supportsContext(any())).willReturn(true);
return registry;
}
@Override
public Class<?> getObjectType() {
return ObservationRegistry.class;
}
public void setHandler(ObservationHandler<Observation.Context> handler) {
this.handler = handler;
}
}
}

View File

@ -0,0 +1,40 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2002-2018 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<b:beans xmlns:b="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://www.springframework.org/schema/security"
xsi:schemaLocation="
http://www.springframework.org/schema/security
https://www.springframework.org/schema/security/spring-security.xsd
http://www.springframework.org/schema/beans
https://www.springframework.org/schema/beans/spring-beans.xsd">
<http auto-config="true" observation-registry-ref="ref" use-authorization-manager="true">
<intercept-url pattern="/**" access="hasRole('USER')"/>
</http>
<b:bean name="handler" class="org.mockito.Mockito" factory-method="mock">
<b:constructor-arg value="io.micrometer.observation.ObservationHandler"/>
</b:bean>
<b:bean name="ref" class="org.springframework.security.config.http.HttpConfigTests.MockObservationRegistry">
<b:property name="handler" ref="handler"/>
</b:bean>
<b:import resource="userservice.xml"/>
</b:beans>

View File

@ -160,6 +160,8 @@ public class FilterChainProxy extends GenericFilterBean {
private ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer();
private FilterChainDecorator filterChainDecorator = new VirtualFilterChainDecorator();
public FilterChainProxy() {
}
@ -214,14 +216,21 @@ public class FilterChainProxy extends GenericFilterBean {
logger.trace(LogMessage.of(() -> "No security for " + requestLine(firewallRequest)));
}
firewallRequest.reset();
chain.doFilter(firewallRequest, firewallResponse);
this.filterChainDecorator.decorate(chain).doFilter(firewallRequest, firewallResponse);
return;
}
if (logger.isDebugEnabled()) {
logger.debug(LogMessage.of(() -> "Securing " + requestLine(firewallRequest)));
}
VirtualFilterChain virtualFilterChain = new VirtualFilterChain(firewallRequest, chain, filters);
virtualFilterChain.doFilter(firewallRequest, firewallResponse);
FilterChain reset = (req, res) -> {
if (logger.isDebugEnabled()) {
logger.debug(LogMessage.of(() -> "Secured " + requestLine(firewallRequest)));
}
// Deactivate path stripping as we exit the security filter chain
firewallRequest.reset();
chain.doFilter(req, res);
};
this.filterChainDecorator.decorate(reset, filters).doFilter(firewallRequest, firewallResponse);
}
/**
@ -249,7 +258,7 @@ public class FilterChainProxy extends GenericFilterBean {
* @return matching filter list
*/
public List<Filter> getFilters(String url) {
return getFilters(this.firewall.getFirewalledRequest((new FilterInvocation(url, "GET").getRequest())));
return getFilters(this.firewall.getFirewalledRequest(new FilterInvocation(url, "GET").getRequest()));
}
/**
@ -281,6 +290,20 @@ public class FilterChainProxy extends GenericFilterBean {
this.filterChainValidator = filterChainValidator;
}
/**
* Used to decorate the original {@link FilterChain} for each request
*
* <p>
* By default, this decorates the filter chain with a {@link VirtualFilterChain} that
* iterates through security filters and then delegates to the original chain
* @param filterChainDecorator the strategy for constructing the filter chain
* @since 6.0
*/
public void setFilterChainDecorator(FilterChainDecorator filterChainDecorator) {
Assert.notNull(filterChainDecorator, "filterChainDecorator cannot be null");
this.filterChainDecorator = filterChainDecorator;
}
/**
* Sets the "firewall" implementation which will be used to validate and wrap (or
* potentially reject) the incoming requests. The default implementation should be
@ -326,36 +349,27 @@ public class FilterChainProxy extends GenericFilterBean {
private final List<Filter> additionalFilters;
private final FirewalledRequest firewalledRequest;
private final int size;
private int currentPosition = 0;
private VirtualFilterChain(FirewalledRequest firewalledRequest, FilterChain chain,
List<Filter> additionalFilters) {
private VirtualFilterChain(FilterChain chain, List<Filter> additionalFilters) {
this.originalChain = chain;
this.additionalFilters = additionalFilters;
this.size = additionalFilters.size();
this.firewalledRequest = firewalledRequest;
}
@Override
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (this.currentPosition == this.size) {
if (logger.isDebugEnabled()) {
logger.debug(LogMessage.of(() -> "Secured " + requestLine(this.firewalledRequest)));
}
// Deactivate path stripping as we exit the security filter chain
this.firewalledRequest.reset();
this.originalChain.doFilter(request, response);
return;
}
this.currentPosition++;
Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1);
if (logger.isTraceEnabled()) {
logger.trace(LogMessage.format("Invoking %s (%d/%d)", nextFilter.getClass().getSimpleName(),
this.currentPosition, this.size));
String name = nextFilter.getClass().getSimpleName();
logger.trace(LogMessage.format("Invoking %s (%d/%d)", name, this.currentPosition, this.size));
}
nextFilter.doFilter(request, response, this);
}
@ -376,4 +390,61 @@ public class FilterChainProxy extends GenericFilterBean {
}
/**
* A strategy for decorating the provided filter chain with one that accounts for the
* {@link SecurityFilterChain} for a given request.
*
* @author Josh Cummings
* @since 6.0
*/
public interface FilterChainDecorator {
/**
* Provide a new {@link FilterChain} that accounts for needed security
* considerations when there are no security filters.
* @param original the original {@link FilterChain}
* @return a security-enabled {@link FilterChain}
*/
default FilterChain decorate(FilterChain original) {
return decorate(original, Collections.emptyList());
}
/**
* Provide a new {@link FilterChain} that accounts for the provided filters as
* well as teh original filter chain.
* @param original the original {@link FilterChain}
* @param filters the security filters
* @return a security-enabled {@link FilterChain} that includes the provided
* filters
*/
FilterChain decorate(FilterChain original, List<Filter> filters);
}
/**
* A {@link FilterChainDecorator} that uses the {@link VirtualFilterChain}
*
* @author Josh Cummings
* @since 6.0
*/
public static final class VirtualFilterChainDecorator implements FilterChainDecorator {
/**
* {@inheritDoc}
*/
@Override
public FilterChain decorate(FilterChain original) {
return original;
}
/**
* {@inheritDoc}
*/
@Override
public FilterChain decorate(FilterChain original, List<Filter> filters) {
return new VirtualFilterChain(original, filters);
}
}
}

View File

@ -0,0 +1,525 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import io.micrometer.common.KeyValues;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.web.util.UrlUtils;
/**
* A {@link org.springframework.security.web.server.FilterChainProxy.FilterChainDecorator}
* that wraps the chain in before and after observations
*
* @author Josh Cummings
* @since 6.0
*/
public final class ObservationFilterChainDecorator implements FilterChainProxy.FilterChainDecorator {
private static final Log logger = LogFactory.getLog(FilterChainProxy.class);
private static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation";
static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests";
static final String SECURED_OBSERVATION_NAME = "spring.security.http.secured.requests";
private final ObservationRegistry registry;
public ObservationFilterChainDecorator(ObservationRegistry registry) {
this.registry = registry;
}
@Override
public FilterChain decorate(FilterChain original) {
return wrapUnsecured(original);
}
@Override
public FilterChain decorate(FilterChain original, List<Filter> filters) {
return new VirtualFilterChain(wrapSecured(original), wrap(filters));
}
private FilterChain wrapSecured(FilterChain original) {
return (req, res) -> {
AroundFilterObservation parent = observation((HttpServletRequest) req);
Observation observation = Observation.createNotStarted(SECURED_OBSERVATION_NAME, this.registry);
parent.wrap(FilterObservation.create(observation).wrap(original)).doFilter(req, res);
};
}
private FilterChain wrapUnsecured(FilterChain original) {
return (req, res) -> {
Observation observation = Observation.createNotStarted(UNSECURED_OBSERVATION_NAME, this.registry);
FilterObservation.create(observation).wrap(original).doFilter(req, res);
};
}
private List<ObservationFilter> wrap(List<Filter> filters) {
int size = filters.size();
List<ObservationFilter> observableFilters = new ArrayList<>();
int position = 1;
for (Filter filter : filters) {
observableFilters.add(new ObservationFilter(this.registry, filter, position, size));
position++;
}
return observableFilters;
}
static AroundFilterObservation observation(HttpServletRequest request) {
return (AroundFilterObservation) request.getAttribute(ATTRIBUTE);
}
private static final class VirtualFilterChain implements FilterChain {
private final FilterChain originalChain;
private final List<ObservationFilter> additionalFilters;
private final int size;
private int currentPosition = 0;
private VirtualFilterChain(FilterChain chain, List<ObservationFilter> additionalFilters) {
this.originalChain = chain;
this.additionalFilters = additionalFilters;
this.size = additionalFilters.size();
}
@Override
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (this.currentPosition == this.size) {
this.originalChain.doFilter(request, response);
return;
}
this.currentPosition++;
ObservationFilter nextFilter = this.additionalFilters.get(this.currentPosition - 1);
if (logger.isTraceEnabled()) {
String name = nextFilter.getName();
logger.trace(LogMessage.format("Invoking %s (%d/%d)", name, this.currentPosition, this.size));
}
nextFilter.doFilter(request, response, this);
}
}
static final class ObservationFilter implements Filter {
private final ObservationRegistry registry;
private final FilterChainObservationConvention convention = new FilterChainObservationConvention();
private final Filter filter;
private final String name;
private final int position;
private final int size;
ObservationFilter(ObservationRegistry registry, Filter filter, int position, int size) {
this.registry = registry;
this.filter = filter;
this.name = filter.getClass().getSimpleName();
this.position = position;
this.size = size;
}
String getName() {
return this.name;
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (this.position == 1) {
AroundFilterObservation parent = parent((HttpServletRequest) request);
parent.wrap(this::wrapFilter).doFilter(request, response, chain);
}
else {
wrapFilter(request, response, chain);
}
}
private void wrapFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
AroundFilterObservation parent = observation((HttpServletRequest) request);
FilterChainObservationContext parentBefore = (FilterChainObservationContext) parent.before().getContext();
parentBefore.setChainSize(this.size);
parentBefore.setFilterName(this.name);
parentBefore.setChainPosition(this.position);
this.filter.doFilter(request, response, chain);
parent.start();
FilterChainObservationContext parentAfter = (FilterChainObservationContext) parent.after().getContext();
parentAfter.setChainSize(this.size);
parentAfter.setFilterName(this.name);
parentAfter.setChainPosition(this.size - this.position + 1);
}
private AroundFilterObservation parent(HttpServletRequest request) {
FilterChainObservationContext beforeContext = FilterChainObservationContext.before(request);
FilterChainObservationContext afterContext = FilterChainObservationContext.after(request);
Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry);
Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry);
AroundFilterObservation parent = AroundFilterObservation.create(before, after);
request.setAttribute(ATTRIBUTE, parent);
return parent;
}
}
interface AroundFilterObservation extends FilterObservation {
AroundFilterObservation NOOP = new AroundFilterObservation() {
};
static AroundFilterObservation create(Observation before, Observation after) {
if (before.isNoop() || after.isNoop()) {
return NOOP;
}
return new SimpleAroundFilterObservation(before, after);
}
default Observation before() {
return Observation.NOOP;
}
default Observation after() {
return Observation.NOOP;
}
class SimpleAroundFilterObservation implements AroundFilterObservation {
private final Iterator<Observation> observations;
private final Observation before;
private final Observation after;
private final AtomicReference<Observation.Scope> currentScope = new AtomicReference<>(null);
SimpleAroundFilterObservation(Observation before, Observation after) {
this.before = before;
this.after = after;
this.observations = Arrays.asList(before, after).iterator();
}
@Override
public void start() {
if (this.observations.hasNext()) {
stop();
Observation observation = this.observations.next();
observation.start();
Observation.Scope scope = observation.openScope();
this.currentScope.set(scope);
}
}
@Override
public void error(Throwable ex) {
Observation.Scope scope = this.currentScope.get();
if (scope == null) {
return;
}
scope.close();
scope.getCurrentObservation().error(ex);
}
@Override
public void stop() {
Observation.Scope scope = this.currentScope.getAndSet(null);
if (scope == null) {
return;
}
scope.close();
scope.getCurrentObservation().stop();
}
@Override
public Filter wrap(Filter filter) {
return (request, response, chain) -> {
start();
try {
filter.doFilter(request, response, chain);
}
catch (Throwable ex) {
error(ex);
throw ex;
}
finally {
stop();
}
};
}
@Override
public FilterChain wrap(FilterChain chain) {
return (request, response) -> {
stop();
try {
chain.doFilter(request, response);
}
finally {
start();
}
};
}
@Override
public Observation before() {
return this.before;
}
@Override
public Observation after() {
return this.after;
}
}
}
interface FilterObservation {
FilterObservation NOOP = new FilterObservation() {
};
static FilterObservation create(Observation observation) {
if (observation.isNoop()) {
return NOOP;
}
return new SimpleFilterObservation(observation);
}
default void start() {
}
default void error(Throwable ex) {
}
default void stop() {
}
default Filter wrap(Filter filter) {
return filter;
}
default FilterChain wrap(FilterChain chain) {
return chain;
}
class SimpleFilterObservation implements FilterObservation {
private final Observation observation;
SimpleFilterObservation(Observation observation) {
this.observation = observation;
}
@Override
public void start() {
this.observation.start();
}
@Override
public void error(Throwable ex) {
this.observation.error(ex);
}
@Override
public void stop() {
this.observation.stop();
}
@Override
public Filter wrap(Filter filter) {
if (this.observation.isNoop()) {
return filter;
}
return (request, response, chain) -> {
this.observation.start();
try (Observation.Scope scope = this.observation.openScope()) {
filter.doFilter(request, response, chain);
}
catch (Throwable ex) {
this.observation.error(ex);
throw ex;
}
finally {
this.observation.stop();
}
};
}
@Override
public FilterChain wrap(FilterChain chain) {
if (this.observation.isNoop()) {
return chain;
}
return (request, response) -> {
this.observation.start();
try (Observation.Scope scope = this.observation.openScope()) {
chain.doFilter(request, response);
}
catch (Throwable ex) {
this.observation.error(ex);
throw ex;
}
finally {
this.observation.stop();
}
};
}
}
}
static final class FilterChainObservationContext extends Observation.Context {
private final ServletRequest request;
private final String filterSection;
private String filterName;
private int chainPosition;
private int chainSize;
private FilterChainObservationContext(ServletRequest request, String filterSection) {
this.filterSection = filterSection;
this.request = request;
}
static FilterChainObservationContext before(ServletRequest request) {
return new FilterChainObservationContext(request, "before");
}
static FilterChainObservationContext after(ServletRequest request) {
return new FilterChainObservationContext(request, "after");
}
@Override
public void setName(String name) {
super.setName(name);
if (name != null) {
setContextualName(name + "." + this.filterSection);
}
}
String getRequestLine() {
return requestLine((HttpServletRequest) this.request);
}
String getFilterSection() {
return this.filterSection;
}
String getFilterName() {
return this.filterName;
}
void setFilterName(String filterName) {
this.filterName = filterName;
}
int getChainPosition() {
return this.chainPosition;
}
void setChainPosition(int chainPosition) {
this.chainPosition = chainPosition;
}
int getChainSize() {
return this.chainSize;
}
void setChainSize(int chainSize) {
this.chainSize = chainSize;
}
private static String requestLine(HttpServletRequest request) {
return request.getMethod() + " " + UrlUtils.buildRequestUrl(request);
}
}
static final class FilterChainObservationConvention
implements ObservationConvention<FilterChainObservationContext> {
static final String CHAIN_OBSERVATION_NAME = "spring.security.http.chains";
private static final String REQUEST_LINE_NAME = "request.line";
private static final String CHAIN_POSITION_NAME = "chain.position";
private static final String CHAIN_SIZE_NAME = "chain.size";
private static final String FILTER_SECTION_NAME = "filter.section";
private static final String FILTER_NAME = "current.filter.name";
@Override
public String getName() {
return CHAIN_OBSERVATION_NAME;
}
@Override
public KeyValues getLowCardinalityKeyValues(FilterChainObservationContext context) {
KeyValues kv = KeyValues.of(CHAIN_SIZE_NAME, String.valueOf(context.getChainSize()))
.and(CHAIN_POSITION_NAME, String.valueOf(context.getChainPosition()))
.and(FILTER_SECTION_NAME, context.getFilterSection());
if (context.getFilterName() != null) {
kv = kv.and(FILTER_NAME, context.getFilterName());
}
return kv;
}
@Override
public KeyValues getHighCardinalityKeyValues(FilterChainObservationContext context) {
String requestLine = context.getRequestLine();
return KeyValues.of(REQUEST_LINE_NAME, requestLine);
}
@Override
public boolean supportsContext(Observation.Context context) {
return context instanceof FilterChainObservationContext;
}
}
}

View File

@ -0,0 +1,531 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.server;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.atomic.AtomicReference;
import io.micrometer.common.KeyValues;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import reactor.core.publisher.Mono;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler;
/**
* A
* {@link org.springframework.security.web.server.WebFilterChainProxy.WebFilterChainDecorator}
* that wraps the chain in before and after observations
*
* @author Josh Cummings
* @since 6.0
*/
public final class ObservationWebFilterChainDecorator implements WebFilterChainProxy.WebFilterChainDecorator {
private static final String ATTRIBUTE = ObservationWebFilterChainDecorator.class + ".observation";
static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests";
static final String SECURED_OBSERVATION_NAME = "spring.security.http.secured.requests";
private final ObservationRegistry registry;
public ObservationWebFilterChainDecorator(ObservationRegistry registry) {
this.registry = registry;
}
@Override
public WebFilterChain decorate(WebFilterChain original) {
return wrapUnsecured(original);
}
@Override
public WebFilterChain decorate(WebFilterChain original, List<WebFilter> filters) {
return new ObservationWebFilterChain(wrapSecured(original)::filter, wrap(filters));
}
private static AroundWebFilterObservation observation(ServerWebExchange exchange) {
return exchange.getAttribute(ATTRIBUTE);
}
private WebFilterChain wrapSecured(WebFilterChain original) {
return (exchange) -> {
AroundWebFilterObservation parent = observation(exchange);
Observation observation = Observation.createNotStarted(SECURED_OBSERVATION_NAME, this.registry);
return parent.wrap(WebFilterObservation.create(observation).wrap(original)).filter(exchange);
};
}
private WebFilterChain wrapUnsecured(WebFilterChain original) {
return (exchange) -> {
Observation observation = Observation.createNotStarted(UNSECURED_OBSERVATION_NAME, this.registry);
return WebFilterObservation.create(observation).wrap(original).filter(exchange);
};
}
private List<ObservationWebFilter> wrap(List<WebFilter> filters) {
int size = filters.size();
List<ObservationWebFilter> observableFilters = new ArrayList<>();
int position = 1;
for (WebFilter filter : filters) {
observableFilters.add(new ObservationWebFilter(this.registry, filter, position, size));
position++;
}
return observableFilters;
}
static class ObservationWebFilterChain implements WebFilterChain {
private final WebHandler handler;
@Nullable
private final ObservationWebFilter currentFilter;
@Nullable
private final ObservationWebFilterChain chain;
/**
* Public constructor with the list of filters and the target handler to use.
* @param handler the target handler
* @param filters the filters ahead of the handler
* @since 5.1
*/
ObservationWebFilterChain(WebHandler handler, List<ObservationWebFilter> filters) {
Assert.notNull(handler, "WebHandler is required");
this.handler = handler;
ObservationWebFilterChain chain = initChain(filters, handler);
this.currentFilter = chain.currentFilter;
this.chain = chain.chain;
}
private static ObservationWebFilterChain initChain(List<ObservationWebFilter> filters, WebHandler handler) {
ObservationWebFilterChain chain = new ObservationWebFilterChain(handler, null, null);
ListIterator<? extends ObservationWebFilter> iterator = filters.listIterator(filters.size());
while (iterator.hasPrevious()) {
chain = new ObservationWebFilterChain(handler, iterator.previous(), chain);
}
return chain;
}
/**
* Private constructor to represent one link in the chain.
*/
private ObservationWebFilterChain(WebHandler handler, @Nullable ObservationWebFilter currentFilter,
@Nullable ObservationWebFilterChain chain) {
this.currentFilter = currentFilter;
this.handler = handler;
this.chain = chain;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange) {
return Mono.defer(() -> (this.currentFilter != null && this.chain != null)
? invokeFilter(this.currentFilter, this.chain, exchange) : this.handler.handle(exchange));
}
private Mono<Void> invokeFilter(ObservationWebFilter current, ObservationWebFilterChain chain,
ServerWebExchange exchange) {
String currentName = current.getName();
return current.filter(exchange, chain).checkpoint(currentName + " [DefaultWebFilterChain]");
}
}
static final class ObservationWebFilter implements WebFilter {
private final ObservationRegistry registry;
private final WebFilterChainObservationConvention convention = new WebFilterChainObservationConvention();
private final WebFilter filter;
private final String name;
private final int position;
private final int size;
ObservationWebFilter(ObservationRegistry registry, WebFilter filter, int position, int size) {
this.registry = registry;
this.filter = filter;
this.name = filter.getClass().getSimpleName();
this.position = position;
this.size = size;
}
String getName() {
return this.name;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (this.position == 1) {
AroundWebFilterObservation parent = parent(exchange);
return parent.wrap(this::wrapFilter).filter(exchange, chain);
}
else {
return wrapFilter(exchange, chain);
}
}
private Mono<Void> wrapFilter(ServerWebExchange exchange, WebFilterChain chain) {
AroundWebFilterObservation parent = observation(exchange);
WebFilterChainObservationContext parentBefore = (WebFilterChainObservationContext) parent.before()
.getContext();
parentBefore.setChainSize(this.size);
parentBefore.setFilterName(this.name);
parentBefore.setChainPosition(this.position);
return this.filter.filter(exchange, chain).doOnSuccess((result) -> {
parent.start();
WebFilterChainObservationContext parentAfter = (WebFilterChainObservationContext) parent.after()
.getContext();
parentAfter.setChainSize(this.size);
parentAfter.setFilterName(this.name);
parentAfter.setChainPosition(this.size - this.position + 1);
});
}
private AroundWebFilterObservation parent(ServerWebExchange exchange) {
WebFilterChainObservationContext beforeContext = WebFilterChainObservationContext.before(exchange);
WebFilterChainObservationContext afterContext = WebFilterChainObservationContext.after(exchange);
Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry);
Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry);
AroundWebFilterObservation parent = AroundWebFilterObservation.create(before, after);
exchange.getAttributes().put(ATTRIBUTE, parent);
return parent;
}
}
interface AroundWebFilterObservation extends WebFilterObservation {
AroundWebFilterObservation NOOP = new AroundWebFilterObservation() {
};
static AroundWebFilterObservation create(Observation before, Observation after) {
if (before.isNoop() || after.isNoop()) {
return NOOP;
}
return new SimpleAroundWebFilterObservation(before, after);
}
default Observation before() {
return Observation.NOOP;
}
default Observation after() {
return Observation.NOOP;
}
class SimpleAroundWebFilterObservation implements AroundWebFilterObservation {
private final Iterator<Observation> observations;
private final Observation before;
private final Observation after;
private final AtomicReference<Observation> currentObservation = new AtomicReference<>(null);
SimpleAroundWebFilterObservation(Observation before, Observation after) {
this.before = before;
this.after = after;
this.observations = Arrays.asList(before, after).iterator();
}
@Override
public void start() {
if (this.observations.hasNext()) {
stop();
Observation observation = this.observations.next();
observation.start();
this.currentObservation.set(observation);
}
}
@Override
public void error(Throwable ex) {
Observation observation = this.currentObservation.get();
if (observation == null) {
return;
}
observation.error(ex);
}
@Override
public void stop() {
Observation observation = this.currentObservation.getAndSet(null);
if (observation == null) {
return;
}
observation.stop();
}
@Override
public WebFilterChain wrap(WebFilterChain chain) {
return (exchange) -> {
stop();
// @formatter:off
return chain.filter(exchange)
.doOnSuccess((v) -> start())
.doOnCancel(this::start)
.doOnError((t) -> {
error(t);
start();
});
// @formatter:on
};
}
@Override
public WebFilter wrap(WebFilter filter) {
return (exchange, chain) -> {
start();
// @formatter:off
return filter.filter(exchange, chain)
.doOnSuccess((v) -> stop())
.doOnCancel(this::stop)
.doOnError((t) -> {
error(t);
stop();
});
// @formatter:on
};
}
@Override
public Observation before() {
return this.before;
}
@Override
public Observation after() {
return this.after;
}
}
}
interface WebFilterObservation {
WebFilterObservation NOOP = new WebFilterObservation() {
};
static WebFilterObservation create(Observation observation) {
if (observation.isNoop()) {
return NOOP;
}
return new SimpleWebFilterObservation(observation);
}
default void start() {
}
default void error(Throwable ex) {
}
default void stop() {
}
default WebFilter wrap(WebFilter filter) {
return filter;
}
default WebFilterChain wrap(WebFilterChain chain) {
return chain;
}
class SimpleWebFilterObservation implements WebFilterObservation {
private final Observation observation;
SimpleWebFilterObservation(Observation observation) {
this.observation = observation;
}
@Override
public void start() {
this.observation.start();
}
@Override
public void error(Throwable ex) {
this.observation.error(ex);
}
@Override
public void stop() {
this.observation.stop();
}
@Override
public WebFilter wrap(WebFilter filter) {
if (this.observation.isNoop()) {
return filter;
}
return (exchange, chain) -> {
this.observation.start();
return filter.filter(exchange, chain).doOnSuccess((v) -> this.observation.stop())
.doOnCancel(this.observation::stop).doOnError((t) -> {
this.observation.error(t);
this.observation.stop();
});
};
}
@Override
public WebFilterChain wrap(WebFilterChain chain) {
if (this.observation.isNoop()) {
return chain;
}
return (exchange) -> {
this.observation.start();
return chain.filter(exchange).doOnSuccess((v) -> this.observation.stop())
.doOnCancel(this.observation::stop).doOnError((t) -> {
this.observation.error(t);
this.observation.stop();
});
};
}
}
}
static final class WebFilterChainObservationContext extends Observation.Context {
private final ServerWebExchange exchange;
private final String filterSection;
private String filterName;
private int chainPosition;
private int chainSize;
private WebFilterChainObservationContext(ServerWebExchange exchange, String filterSection) {
this.exchange = exchange;
this.filterSection = filterSection;
}
static WebFilterChainObservationContext before(ServerWebExchange exchange) {
return new WebFilterChainObservationContext(exchange, "before");
}
static WebFilterChainObservationContext after(ServerWebExchange exchange) {
return new WebFilterChainObservationContext(exchange, "after");
}
@Override
public void setName(String name) {
super.setName(name);
if (name != null) {
setContextualName(name + "." + this.filterSection);
}
}
String getRequestLine() {
return this.exchange.getRequest().getPath().toString();
}
String getFilterSection() {
return this.filterSection;
}
String getFilterName() {
return this.filterName;
}
void setFilterName(String filterName) {
this.filterName = filterName;
}
int getChainPosition() {
return this.chainPosition;
}
void setChainPosition(int chainPosition) {
this.chainPosition = chainPosition;
}
int getChainSize() {
return this.chainSize;
}
void setChainSize(int chainSize) {
this.chainSize = chainSize;
}
}
static final class WebFilterChainObservationConvention
implements ObservationConvention<WebFilterChainObservationContext> {
static final String CHAIN_OBSERVATION_NAME = "spring.security.http.chains";
private static final String REQUEST_LINE_NAME = "request.line";
private static final String CHAIN_POSITION_NAME = "chain.position";
private static final String CHAIN_SIZE_NAME = "chain.size";
private static final String FILTER_SECTION_NAME = "filter.section";
private static final String FILTER_NAME = "current.filter.name";
@Override
public String getName() {
return CHAIN_OBSERVATION_NAME;
}
@Override
public KeyValues getLowCardinalityKeyValues(WebFilterChainObservationContext context) {
KeyValues kv = KeyValues.of(CHAIN_SIZE_NAME, String.valueOf(context.getChainSize()))
.and(CHAIN_POSITION_NAME, String.valueOf(context.getChainPosition()))
.and(FILTER_SECTION_NAME, context.getFilterSection());
if (context.getFilterName() != null) {
kv = kv.and(FILTER_NAME, context.getFilterName());
}
return kv;
}
@Override
public KeyValues getHighCardinalityKeyValues(WebFilterChainObservationContext context) {
String requestLine = context.getRequestLine();
return KeyValues.of(REQUEST_LINE_NAME, requestLine);
}
@Override
public boolean supportsContext(Observation.Context context) {
return context instanceof WebFilterChainObservationContext;
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,11 +17,15 @@
package org.springframework.security.web.server;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import jakarta.servlet.FilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
@ -37,6 +41,8 @@ public class WebFilterChainProxy implements WebFilter {
private final List<SecurityWebFilterChain> filters;
private WebFilterChainDecorator filterChainDecorator = new DefaultWebFilterChainDecorator();
public WebFilterChainProxy(List<SecurityWebFilterChain> filters) {
this.filters = filters;
}
@ -49,10 +55,82 @@ public class WebFilterChainProxy implements WebFilter {
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return Flux.fromIterable(this.filters)
.filterWhen((securityWebFilterChain) -> securityWebFilterChain.matches(exchange)).next()
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.switchIfEmpty(
Mono.defer(() -> this.filterChainDecorator.decorate(chain).filter(exchange).then(Mono.empty())))
.flatMap((securityWebFilterChain) -> securityWebFilterChain.getWebFilters().collectList())
.map((filters) -> new DefaultWebFilterChain(chain::filter, filters))
.map((filters) -> this.filterChainDecorator.decorate(chain, filters))
.flatMap((securedChain) -> securedChain.filter(exchange));
}
/**
* Used to decorate the original {@link FilterChain} for each request
*
* <p>
* By default, this decorates the filter chain with a {@link DefaultWebFilterChain}
* that iterates through security filters and then delegates to the original chain
* @param filterChainDecorator the strategy for constructing the filter chain
* @since 6.0
*/
public void setFilterChainDecorator(WebFilterChainDecorator filterChainDecorator) {
Assert.notNull(filterChainDecorator, "filterChainDecorator cannot be null");
this.filterChainDecorator = filterChainDecorator;
}
/**
* A strategy for decorating the provided filter chain with one that accounts for the
* {@link SecurityFilterChain} for a given request.
*
* @author Josh Cummings
* @since 6.0
*/
public interface WebFilterChainDecorator {
/**
* Provide a new {@link FilterChain} that accounts for needed security
* considerations when there are no security filters.
* @param original the original {@link FilterChain}
* @return a security-enabled {@link FilterChain}
*/
default WebFilterChain decorate(WebFilterChain original) {
return decorate(original, Collections.emptyList());
}
/**
* Provide a new {@link FilterChain} that accounts for the provided filters as
* well as teh original filter chain.
* @param original the original {@link FilterChain}
* @param filters the security filters
* @return a security-enabled {@link FilterChain} that includes the provided
* filters
*/
WebFilterChain decorate(WebFilterChain original, List<WebFilter> filters);
}
/**
* A {@link WebFilterChainDecorator} that uses the {@link DefaultWebFilterChain}
*
* @author Josh Cummings
* @since 6.0
*/
public static class DefaultWebFilterChainDecorator implements WebFilterChainDecorator {
/**
* {@inheritDoc}
*/
@Override
public WebFilterChain decorate(WebFilterChain original) {
return original;
}
/**
* {@inheritDoc}
*/
@Override
public WebFilterChain decorate(WebFilterChain original, List<WebFilter> filters) {
return new DefaultWebFilterChain(original::filter, filters);
}
}
}

View File

@ -16,19 +16,27 @@
package org.springframework.security.web;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.springframework.mock.web.MockHttpServletRequest;
@ -50,7 +58,9 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -270,10 +280,198 @@ public class FilterChainProxyTests {
this.fcp.setRequestRejectedHandler(rjh);
RequestRejectedException requestRejectedException = new RequestRejectedException("Contains illegal chars");
ServletException servletException = new ServletException(requestRejectedException);
given(fw.getFirewalledRequest(this.request)).willReturn(mock(FirewalledRequest.class));
given(fw.getFirewalledRequest(this.request)).willReturn(new MockFirewalledRequest(this.request));
willThrow(servletException).given(this.chain).doFilter(any(), any());
this.fcp.doFilter(this.request, this.response, this.chain);
verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException)));
}
@Test
public void doFilterWhenMatchesThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
given(this.matcher.matches(any())).willReturn(true);
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter));
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter filter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
filter.doFilter(this.request, this.response, this.chain);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(4)).onStart(captor.capture());
verify(handler, times(4)).onStop(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 1);
assertThat(contexts.next().getName()).isEqualTo(ObservationFilterChainDecorator.SECURED_OBSERVATION_NAME);
assertFilterChainObservation(contexts.next(), "after", 1);
}
@Test
public void doFilterWhenMultipleFiltersThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
given(this.matcher.matches(any())).willReturn(true);
Filter one = mockFilter();
Filter two = mockFilter();
Filter three = mockFilter();
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, one, two, three);
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter filter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
filter.doFilter(this.request, this.response, this.chain);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(4)).onStart(captor.capture());
verify(handler, times(4)).onStop(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 3);
assertThat(contexts.next().getName()).isEqualTo(ObservationFilterChainDecorator.SECURED_OBSERVATION_NAME);
assertFilterChainObservation(contexts.next(), "after", 3);
}
@Test
public void doFilterWhenMismatchesThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter));
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter filter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
filter.doFilter(this.request, this.response, this.chain);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(2)).onStart(captor.capture());
verify(handler, times(2)).onStop(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertThat(contexts.next().getName()).isEqualTo(ObservationFilterChainDecorator.UNSECURED_OBSERVATION_NAME);
}
@Test
public void doFilterWhenFilterExceptionThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
willThrow(IllegalStateException.class).given(this.filter).doFilter(any(), any(), any());
given(this.matcher.matches(any())).willReturn(true);
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter));
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter filter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
assertThatExceptionOfType(IllegalStateException.class)
.isThrownBy(() -> filter.doFilter(this.request, this.response, this.chain));
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(2)).onStart(captor.capture());
verify(handler, times(2)).onStop(any());
verify(handler, atLeastOnce()).onError(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 1);
}
@Test
public void doFilterWhenExceptionWithMultipleFiltersThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
given(this.matcher.matches(any())).willReturn(true);
Filter one = mockFilter();
Filter two = mock(Filter.class);
willThrow(IllegalStateException.class).given(two).doFilter(any(), any(), any());
Filter three = mockFilter();
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, one, two, three);
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter filter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
assertThatExceptionOfType(IllegalStateException.class)
.isThrownBy(() -> filter.doFilter(this.request, this.response, this.chain));
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(2)).onStart(captor.capture());
verify(handler, times(2)).onStop(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 2);
}
@Test
public void doFilterWhenOneFilterDoesNotProceedThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
given(this.matcher.matches(any())).willReturn(true);
Filter one = mockFilter();
Filter two = mock(Filter.class);
Filter three = mockFilter();
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, one, two, three);
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter filter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
filter.doFilter(this.request, this.response, this.chain);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(3)).onStart(captor.capture());
verify(handler, times(3)).onStop(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 2);
assertFilterChainObservation(contexts.next(), "after", 3);
}
static void assertFilterChainObservation(Observation.Context context, String filterSection, int chainPosition) {
assertThat(context).isInstanceOf(ObservationFilterChainDecorator.FilterChainObservationContext.class);
ObservationFilterChainDecorator.FilterChainObservationContext filterChainObservationContext = (ObservationFilterChainDecorator.FilterChainObservationContext) context;
assertThat(context.getName())
.isEqualTo(ObservationFilterChainDecorator.FilterChainObservationConvention.CHAIN_OBSERVATION_NAME);
assertThat(context.getContextualName()).endsWith(filterSection);
assertThat(filterChainObservationContext.getChainPosition()).isEqualTo(chainPosition);
}
static Filter mockFilter() throws Exception {
Filter filter = mock(Filter.class);
willAnswer((invocation) -> {
HttpServletRequest request = invocation.getArgument(0, HttpServletRequest.class);
HttpServletResponse response = invocation.getArgument(1, HttpServletResponse.class);
FilterChain chain = invocation.getArgument(2, FilterChain.class);
chain.doFilter(request, response);
return null;
}).given(filter).doFilter(any(), any(), any());
return filter;
}
private static class MockFirewalledRequest extends FirewalledRequest {
MockFirewalledRequest(HttpServletRequest request) {
super(request);
}
@Override
public void reset() {
}
}
private static class MockFilter implements Filter {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
chain.doFilter(request, response);
}
}
}

View File

@ -17,12 +17,22 @@
package org.springframework.security.web.server;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.web.server.ObservationWebFilterChainDecorator.WebFilterChainObservationContext;
import org.springframework.security.web.server.ObservationWebFilterChainDecorator.WebFilterChainObservationConvention;
import org.springframework.security.web.server.ObservationWebFilterChainDecorator.WebFilterObservation;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
import org.springframework.test.web.reactive.server.WebTestClient;
@ -30,6 +40,15 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* @author Rob Winch
* @since 5.0
@ -47,6 +66,86 @@ public class WebFilterChainProxyTests {
.isNotFound();
}
@Test
public void doFilterWhenMatchesThenObservationRegistryObserves() {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
List<WebFilter> filters = Arrays.asList(new PassthroughWebFilter());
ServerWebExchangeMatcher match = (exchange) -> MatchResult.match();
MatcherSecurityWebFilterChain chain = new MatcherSecurityWebFilterChain(match, filters);
WebFilterChainProxy fcp = new WebFilterChainProxy(chain);
fcp.setFilterChainDecorator(new ObservationWebFilterChainDecorator(registry));
WebFilter filter = WebFilterObservation.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
WebFilterChain mockChain = mock(WebFilterChain.class);
given(mockChain.filter(any())).willReturn(Mono.empty());
filter.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/")), mockChain).block();
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(4)).onStart(captor.capture());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 1);
assertThat(contexts.next().getName()).isEqualTo(ObservationWebFilterChainDecorator.SECURED_OBSERVATION_NAME);
assertFilterChainObservation(contexts.next(), "after", 1);
}
@Test
public void doFilterWhenMismatchesThenObservationRegistryObserves() {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
List<WebFilter> filters = Arrays.asList(new PassthroughWebFilter());
ServerWebExchangeMatcher notMatch = (exchange) -> MatchResult.notMatch();
MatcherSecurityWebFilterChain chain = new MatcherSecurityWebFilterChain(notMatch, filters);
WebFilterChainProxy fcp = new WebFilterChainProxy(chain);
fcp.setFilterChainDecorator(new ObservationWebFilterChainDecorator(registry));
WebFilter filter = WebFilterObservation.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
WebFilterChain mockChain = mock(WebFilterChain.class);
given(mockChain.filter(any())).willReturn(Mono.empty());
filter.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/")), mockChain).block();
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(2)).onStart(captor.capture());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertThat(contexts.next().getName()).isEqualTo(ObservationWebFilterChainDecorator.UNSECURED_OBSERVATION_NAME);
}
@Test
public void doFilterWhenFilterExceptionThenObservationRegistryObserves() {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);
WebFilter error = mock(WebFilter.class);
given(error.filter(any(), any())).willReturn(Mono.error(new IllegalStateException()));
List<WebFilter> filters = Arrays.asList(error);
ServerWebExchangeMatcher match = (exchange) -> MatchResult.match();
MatcherSecurityWebFilterChain chain = new MatcherSecurityWebFilterChain(match, filters);
WebFilterChainProxy fcp = new WebFilterChainProxy(chain);
fcp.setFilterChainDecorator(new ObservationWebFilterChainDecorator(registry));
WebFilter filter = WebFilterObservation.create(Observation.createNotStarted("wrap", registry)).wrap(fcp);
WebFilterChain mockChain = mock(WebFilterChain.class);
given(mockChain.filter(any())).willReturn(Mono.empty());
assertThatExceptionOfType(IllegalStateException.class).isThrownBy(
() -> filter.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/")), mockChain).block());
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(2)).onStart(captor.capture());
verify(handler, atLeastOnce()).onError(any());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getName()).isEqualTo("wrap");
assertFilterChainObservation(contexts.next(), "before", 1);
}
static void assertFilterChainObservation(Observation.Context context, String filterSection, int chainPosition) {
assertThat(context).isInstanceOf(WebFilterChainObservationContext.class);
WebFilterChainObservationContext filterChainObservationContext = (WebFilterChainObservationContext) context;
assertThat(context.getName()).isEqualTo(WebFilterChainObservationConvention.CHAIN_OBSERVATION_NAME);
assertThat(context.getContextualName()).endsWith(filterSection);
assertThat(filterChainObservationContext.getChainPosition()).isEqualTo(chainPosition);
}
static class Http200WebFilter implements WebFilter {
@Override
@ -56,4 +155,13 @@ public class WebFilterChainProxyTests {
}
static class PassthroughWebFilter implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return chain.filter(exchange);
}
}
}