From a31a855146d7485a9e30cdb70a18be212e4d008f Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 14 Apr 2021 17:06:18 -0500 Subject: [PATCH] Fix HttpSecurity.addFilter* Ordering Closes gh-9633 --- ...ator.java => FilterOrderRegistration.java} | 62 +------- .../annotation/web/builders/HttpSecurity.java | 75 ++++++++-- .../builders/HttpSecurityAddFilterTest.java | 132 ++++++++++++++++++ 3 files changed, 196 insertions(+), 73 deletions(-) rename config/src/main/java/org/springframework/security/config/annotation/web/builders/{FilterComparator.java => FilterOrderRegistration.java} (75%) create mode 100644 config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java similarity index 75% rename from config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java rename to config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java index c9ffaa80b8..f207044010 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java @@ -16,7 +16,6 @@ package org.springframework.security.config.annotation.web.builders; -import java.io.Serializable; import java.util.Comparator; import java.util.HashMap; import java.util.Map; @@ -47,7 +46,6 @@ import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.session.ConcurrentSessionFilter; import org.springframework.security.web.session.SessionManagementFilter; -import org.springframework.util.Assert; import org.springframework.web.filter.CorsFilter; /** @@ -59,7 +57,7 @@ import org.springframework.web.filter.CorsFilter; */ @SuppressWarnings("serial") -final class FilterComparator implements Comparator, Serializable { +final class FilterOrderRegistration { private static final int INITIAL_ORDER = 100; @@ -67,7 +65,7 @@ final class FilterComparator implements Comparator, Serializable { private final Map filterToOrder = new HashMap<>(); - FilterComparator() { + FilterOrderRegistration() { Step order = new Step(INITIAL_ORDER, ORDER_STEP); put(ChannelProcessingFilter.class, order.next()); order.next(); // gh-8105 @@ -116,60 +114,6 @@ final class FilterComparator implements Comparator, Serializable { put(SwitchUserFilter.class, order.next()); } - @Override - public int compare(Filter lhs, Filter rhs) { - Integer left = getOrder(lhs.getClass()); - Integer right = getOrder(rhs.getClass()); - return left - right; - } - - /** - * Determines if a particular {@link Filter} is registered to be sorted - * @param filter - * @return - */ - boolean isRegistered(Class filter) { - return getOrder(filter) != null; - } - - /** - * Registers a {@link Filter} to exist after a particular {@link Filter} that is - * already registered. - * @param filter the {@link Filter} to register - * @param afterFilter the {@link Filter} that is already registered and that - * {@code filter} should be placed after. - */ - void registerAfter(Class filter, Class afterFilter) { - Integer position = getOrder(afterFilter); - Assert.notNull(position, () -> "Cannot register after unregistered Filter " + afterFilter); - put(filter, position + 1); - } - - /** - * Registers a {@link Filter} to exist at a particular {@link Filter} position - * @param filter the {@link Filter} to register - * @param atFilter the {@link Filter} that is already registered and that - * {@code filter} should be placed at. - */ - void registerAt(Class filter, Class atFilter) { - Integer position = getOrder(atFilter); - Assert.notNull(position, () -> "Cannot register after unregistered Filter " + atFilter); - put(filter, position); - } - - /** - * Registers a {@link Filter} to exist before a particular {@link Filter} that is - * already registered. - * @param filter the {@link Filter} to register - * @param beforeFilter the {@link Filter} that is already registered and that - * {@code filter} should be placed before. - */ - void registerBefore(Class filter, Class beforeFilter) { - Integer position = getOrder(beforeFilter); - Assert.notNull(position, () -> "Cannot register after unregistered Filter " + beforeFilter); - put(filter, position - 1); - } - private void put(Class filter, int position) { String className = filter.getName(); this.filterToOrder.put(className, position); @@ -181,7 +125,7 @@ final class FilterComparator implements Comparator, Serializable { * @param clazz the {@link Filter} class to determine the sort order * @return the sort order or null if not defined */ - private Integer getOrder(Class clazz) { + Integer getOrder(Class clazz) { while (clazz != null) { Integer result = this.filterToOrder.get(clazz.getName()); if (result != null) { diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java index 0b11a22e90..65c2a57e5d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java @@ -16,14 +16,21 @@ package org.springframework.security.config.annotation.web.builders; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import org.springframework.context.ApplicationContext; +import org.springframework.core.OrderComparator; +import org.springframework.core.Ordered; import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationProvider; @@ -129,11 +136,11 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder filters = new ArrayList<>(); + private List filters = new ArrayList<>(); private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE; - private FilterComparator comparator = new FilterComparator(); + private FilterOrderRegistration filterOrders = new FilterOrderRegistration(); /** * Creates a new instance @@ -2609,8 +2616,12 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder sortedFilters = new ArrayList<>(this.filters.size()); + for (Filter filter : this.filters) { + sortedFilters.add(((OrderedFilter) filter).filter); + } + return new DefaultSecurityFilterChain(this.requestMatcher, sortedFilters); } @Override @@ -2631,24 +2642,28 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder afterFilter) { - this.comparator.registerAfter(filter.getClass(), afterFilter); - return addFilter(filter); + return addFilterAtOffsetOf(filter, 1, afterFilter); } @Override public HttpSecurity addFilterBefore(Filter filter, Class beforeFilter) { - this.comparator.registerBefore(filter.getClass(), beforeFilter); - return addFilter(filter); + return addFilterAtOffsetOf(filter, -1, beforeFilter); + } + + private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class registeredFilter) { + int order = this.filterOrders.getOrder(registeredFilter) + offset; + this.filters.add(new OrderedFilter(filter, order)); + return this; } @Override public HttpSecurity addFilter(Filter filter) { - Class filterClass = filter.getClass(); - if (!this.comparator.isRegistered(filterClass)) { - throw new IllegalArgumentException("The Filter class " + filterClass.getName() + Integer order = this.filterOrders.getOrder(filter.getClass()); + if (order == null) { + throw new IllegalArgumentException("The Filter class " + filter.getClass().getName() + " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead."); } - this.filters.add(filter); + this.filters.add(new OrderedFilter(filter, order)); return this; } @@ -2671,8 +2686,7 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder atFilter) { - this.comparator.registerAt(filter.getClass(), atFilter); - return addFilter(filter); + return addFilterAtOffsetOf(filter, 0, atFilter); } /** @@ -3060,4 +3074,37 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder> assertThatFilters() { + FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); + List> filters = filterChain.getFilters("/").stream().map(Object::getClass) + .collect(Collectors.toList()); + return assertThat(filters); + } + + public static class MyFilter implements Filter { + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + filterChain.doFilter(servletRequest, servletResponse); + } + + } + + @EnableWebSecurity + static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class) + .addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class); + // @formatter:on + } + + } + + @EnableWebSecurity + static class MyFilterMultipleBeforeConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class) + .addFilterBefore(new MyFilter(), ExceptionTranslationFilter.class); + // @formatter:on + } + + } + + @EnableWebSecurity + static class MyFilterMultipleAtConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterAt(new MyFilter(), ChannelProcessingFilter.class) + .addFilterAt(new MyFilter(), UsernamePasswordAuthenticationFilter.class); + // @formatter:on + } + + } + +}