From adf3e94c9f2b15de45e5ce1e254f2ffd4f48bd70 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 14 Apr 2021 21:01:09 -0500 Subject: [PATCH] Fix HttpSecurity.addFilter* Ordering Closes gh-9633 --- ...ator.java => FilterOrderRegistration.java} | 78 +---------- .../annotation/web/builders/HttpSecurity.java | 81 ++++++++--- .../builders/HttpSecurityAddFilterTest.java | 132 ++++++++++++++++++ 3 files changed, 199 insertions(+), 92 deletions(-) rename config/src/main/java/org/springframework/security/config/annotation/web/builders/{FilterComparator.java => FilterOrderRegistration.java} (73%) create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java similarity index 73% rename from config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java rename to config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java index ab1fd36791..e37c14ae66 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java @@ -15,7 +15,6 @@ */ package org.springframework.security.config.annotation.web.builders; -import java.io.Serializable; import java.util.Comparator; import java.util.HashMap; import java.util.Map; @@ -53,14 +52,12 @@ import org.springframework.web.filter.CorsFilter; * @author Rob Winch * @since 3.2 */ - -@SuppressWarnings("serial") -final class FilterComparator implements Comparator, Serializable { +final class FilterOrderRegistration { private static final int INITIAL_ORDER = 100; private static final int ORDER_STEP = 100; private final Map filterToOrder = new HashMap<>(); - FilterComparator() { + FilterOrderRegistration() { Step order = new Step(INITIAL_ORDER, ORDER_STEP); put(ChannelProcessingFilter.class, order.next()); put(ConcurrentSessionFilter.class, order.next()); @@ -111,75 +108,6 @@ final class FilterComparator implements Comparator, Serializable { put(SwitchUserFilter.class, order.next()); } - public int compare(Filter lhs, Filter rhs) { - Integer left = getOrder(lhs.getClass()); - Integer right = getOrder(rhs.getClass()); - return left - right; - } - - /** - * Determines if a particular {@link Filter} is registered to be sorted - * - * @param filter - * @return - */ - public boolean isRegistered(Class filter) { - return getOrder(filter) != null; - } - - /** - * Registers a {@link Filter} to exist after a particular {@link Filter} that is - * already registered. - * @param filter the {@link Filter} to register - * @param afterFilter the {@link Filter} that is already registered and that - * {@code filter} should be placed after. - */ - public void registerAfter(Class filter, - Class afterFilter) { - Integer position = getOrder(afterFilter); - if (position == null) { - throw new IllegalArgumentException( - "Cannot register after unregistered Filter " + afterFilter); - } - - put(filter, position + 1); - } - - /** - * Registers a {@link Filter} to exist at a particular {@link Filter} position - * @param filter the {@link Filter} to register - * @param atFilter the {@link Filter} that is already registered and that - * {@code filter} should be placed at. - */ - public void registerAt(Class filter, - Class atFilter) { - Integer position = getOrder(atFilter); - if (position == null) { - throw new IllegalArgumentException( - "Cannot register after unregistered Filter " + atFilter); - } - - put(filter, position); - } - - /** - * Registers a {@link Filter} to exist before a particular {@link Filter} that is - * already registered. - * @param filter the {@link Filter} to register - * @param beforeFilter the {@link Filter} that is already registered and that - * {@code filter} should be placed before. - */ - public void registerBefore(Class filter, - Class beforeFilter) { - Integer position = getOrder(beforeFilter); - if (position == null) { - throw new IllegalArgumentException( - "Cannot register after unregistered Filter " + beforeFilter); - } - - put(filter, position - 1); - } - private void put(Class filter, int position) { String className = filter.getName(); filterToOrder.put(className, position); @@ -192,7 +120,7 @@ final class FilterComparator implements Comparator, Serializable { * @param clazz the {@link Filter} class to determine the sort order * @return the sort order or null if not defined */ - private Integer getOrder(Class clazz) { + Integer getOrder(Class clazz) { while (clazz != null) { Integer result = filterToOrder.get(clazz.getName()); if (result != null) { diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java index bce40f4cd7..31a0e8686a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.web.builders; import org.springframework.context.ApplicationContext; +import org.springframework.core.OrderComparator; +import org.springframework.core.Ordered; import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationProvider; @@ -78,10 +80,16 @@ import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.filter.CorsFilter; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; + import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; /** @@ -125,9 +133,9 @@ public final class HttpSecurity extends implements SecurityBuilder, HttpSecurityBuilder { private final RequestMatcherConfigurer requestMatcherConfigurer; - private List filters = new ArrayList<>(); + private List filters = new ArrayList<>(); private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE; - private FilterComparator comparator = new FilterComparator(); + private FilterOrderRegistration filterOrders = new FilterOrderRegistration(); /** * Creates a new instance @@ -2528,8 +2536,12 @@ public final class HttpSecurity extends @Override protected DefaultSecurityFilterChain performBuild() { - filters.sort(comparator); - return new DefaultSecurityFilterChain(requestMatcher, filters); + this.filters.sort(OrderComparator.INSTANCE); + List sortedFilters = new ArrayList<>(this.filters.size()); + for (Filter filter : this.filters) { + sortedFilters.add(((OrderedFilter) filter).filter); + } + return new DefaultSecurityFilterChain(this.requestMatcher, sortedFilters); } /* @@ -2570,8 +2582,7 @@ public final class HttpSecurity extends * .servlet.Filter, java.lang.Class) */ public HttpSecurity addFilterAfter(Filter filter, Class afterFilter) { - comparator.registerAfter(filter.getClass(), afterFilter); - return addFilter(filter); + return addFilterAtOffsetOf(filter, 1, afterFilter); } /* @@ -2583,8 +2594,13 @@ public final class HttpSecurity extends */ public HttpSecurity addFilterBefore(Filter filter, Class beforeFilter) { - comparator.registerBefore(filter.getClass(), beforeFilter); - return addFilter(filter); + return addFilterAtOffsetOf(filter, -1, beforeFilter); + } + + private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class registeredFilter) { + int order = this.filterOrders.getOrder(registeredFilter) + offset; + this.filters.add(new OrderedFilter(filter, order)); + return this; } /* @@ -2595,14 +2611,12 @@ public final class HttpSecurity extends * servlet.Filter) */ public HttpSecurity addFilter(Filter filter) { - Class filterClass = filter.getClass(); - if (!comparator.isRegistered(filterClass)) { - throw new IllegalArgumentException( - "The Filter class " - + filterClass.getName() - + " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead."); + Integer order = this.filterOrders.getOrder(filter.getClass()); + if (order == null) { + throw new IllegalArgumentException("The Filter class " + filter.getClass().getName() + + " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead."); } - this.filters.add(filter); + this.filters.add(new OrderedFilter(filter, order)); return this; } @@ -2626,8 +2640,7 @@ public final class HttpSecurity extends * @return the {@link HttpSecurity} for further customizations */ public HttpSecurity addFilterAt(Filter filter, Class atFilter) { - this.comparator.registerAt(filter.getClass(), atFilter); - return addFilter(filter); + return addFilterAtOffsetOf(filter, 0, atFilter); } /** @@ -3023,4 +3036,38 @@ public final class HttpSecurity extends } return apply(configurer); } + + /* + * A Filter that implements Ordered to be sorted. After sorting occurs, the original + * filter is what is used by FilterChainProxy + */ + private static final class OrderedFilter implements Ordered, Filter { + + private final Filter filter; + + private final int order; + + private OrderedFilter(Filter filter, int order) { + this.filter = filter; + this.order = order; + } + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + this.filter.doFilter(servletRequest, servletResponse, filterChain); + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public String toString() { + return "OrderedFilter{" + "filter=" + this.filter + ", order=" + this.order + '}'; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java b/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java new file mode 100644 index 0000000000..b83818503f --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.builders; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.assertj.core.api.ListAssert; +import org.junit.Rule; +import org.junit.Test; + +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.access.ExceptionTranslationFilter; +import org.springframework.security.web.access.channel.ChannelProcessingFilter; +import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; +import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter; + +import static org.assertj.core.api.Assertions.assertThat; + +public class HttpSecurityAddFilterTest { + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Test + public void addFilterAfterWhenSameFilterDifferentPlacesThenOrderCorrect() { + this.spring.register(MyFilterMultipleAfterConfig.class).autowire(); + + assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class, + ExceptionTranslationFilter.class, MyFilter.class); + } + + @Test + public void addFilterBeforeWhenSameFilterDifferentPlacesThenOrderCorrect() { + this.spring.register(MyFilterMultipleBeforeConfig.class).autowire(); + + assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class, + ExceptionTranslationFilter.class); + } + + @Test + public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() { + this.spring.register(MyFilterMultipleAtConfig.class).autowire(); + + assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class, + ExceptionTranslationFilter.class); + } + + private ListAssert> assertThatFilters() { + FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); + List> filters = filterChain.getFilters("/").stream().map(Object::getClass) + .collect(Collectors.toList()); + return assertThat(filters); + } + + public static class MyFilter implements Filter { + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + filterChain.doFilter(servletRequest, servletResponse); + } + + } + + @EnableWebSecurity + static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class) + .addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class); + // @formatter:on + } + + } + + @EnableWebSecurity + static class MyFilterMultipleBeforeConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class) + .addFilterBefore(new MyFilter(), ExceptionTranslationFilter.class); + // @formatter:on + } + + } + + @EnableWebSecurity + static class MyFilterMultipleAtConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterAt(new MyFilter(), ChannelProcessingFilter.class) + .addFilterAt(new MyFilter(), UsernamePasswordAuthenticationFilter.class); + // @formatter:on + } + + } + +}