mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-05-31 17:22:13 +00:00
Fix HttpSecurity.addFilter* Ordering
Closes gh-9633
This commit is contained in:
parent
2b4b856b32
commit
a31a855146
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
package org.springframework.security.config.annotation.web.builders;
|
package org.springframework.security.config.annotation.web.builders;
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@ -47,7 +46,6 @@ import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
|
|||||||
import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter;
|
import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter;
|
||||||
import org.springframework.security.web.session.ConcurrentSessionFilter;
|
import org.springframework.security.web.session.ConcurrentSessionFilter;
|
||||||
import org.springframework.security.web.session.SessionManagementFilter;
|
import org.springframework.security.web.session.SessionManagementFilter;
|
||||||
import org.springframework.util.Assert;
|
|
||||||
import org.springframework.web.filter.CorsFilter;
|
import org.springframework.web.filter.CorsFilter;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -59,7 +57,7 @@ import org.springframework.web.filter.CorsFilter;
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
@SuppressWarnings("serial")
|
@SuppressWarnings("serial")
|
||||||
final class FilterComparator implements Comparator<Filter>, Serializable {
|
final class FilterOrderRegistration {
|
||||||
|
|
||||||
private static final int INITIAL_ORDER = 100;
|
private static final int INITIAL_ORDER = 100;
|
||||||
|
|
||||||
@ -67,7 +65,7 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
|
|||||||
|
|
||||||
private final Map<String, Integer> filterToOrder = new HashMap<>();
|
private final Map<String, Integer> filterToOrder = new HashMap<>();
|
||||||
|
|
||||||
FilterComparator() {
|
FilterOrderRegistration() {
|
||||||
Step order = new Step(INITIAL_ORDER, ORDER_STEP);
|
Step order = new Step(INITIAL_ORDER, ORDER_STEP);
|
||||||
put(ChannelProcessingFilter.class, order.next());
|
put(ChannelProcessingFilter.class, order.next());
|
||||||
order.next(); // gh-8105
|
order.next(); // gh-8105
|
||||||
@ -116,60 +114,6 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
|
|||||||
put(SwitchUserFilter.class, order.next());
|
put(SwitchUserFilter.class, order.next());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public int compare(Filter lhs, Filter rhs) {
|
|
||||||
Integer left = getOrder(lhs.getClass());
|
|
||||||
Integer right = getOrder(rhs.getClass());
|
|
||||||
return left - right;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Determines if a particular {@link Filter} is registered to be sorted
|
|
||||||
* @param filter
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean isRegistered(Class<? extends Filter> filter) {
|
|
||||||
return getOrder(filter) != null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Registers a {@link Filter} to exist after a particular {@link Filter} that is
|
|
||||||
* already registered.
|
|
||||||
* @param filter the {@link Filter} to register
|
|
||||||
* @param afterFilter the {@link Filter} that is already registered and that
|
|
||||||
* {@code filter} should be placed after.
|
|
||||||
*/
|
|
||||||
void registerAfter(Class<? extends Filter> filter, Class<? extends Filter> afterFilter) {
|
|
||||||
Integer position = getOrder(afterFilter);
|
|
||||||
Assert.notNull(position, () -> "Cannot register after unregistered Filter " + afterFilter);
|
|
||||||
put(filter, position + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Registers a {@link Filter} to exist at a particular {@link Filter} position
|
|
||||||
* @param filter the {@link Filter} to register
|
|
||||||
* @param atFilter the {@link Filter} that is already registered and that
|
|
||||||
* {@code filter} should be placed at.
|
|
||||||
*/
|
|
||||||
void registerAt(Class<? extends Filter> filter, Class<? extends Filter> atFilter) {
|
|
||||||
Integer position = getOrder(atFilter);
|
|
||||||
Assert.notNull(position, () -> "Cannot register after unregistered Filter " + atFilter);
|
|
||||||
put(filter, position);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Registers a {@link Filter} to exist before a particular {@link Filter} that is
|
|
||||||
* already registered.
|
|
||||||
* @param filter the {@link Filter} to register
|
|
||||||
* @param beforeFilter the {@link Filter} that is already registered and that
|
|
||||||
* {@code filter} should be placed before.
|
|
||||||
*/
|
|
||||||
void registerBefore(Class<? extends Filter> filter, Class<? extends Filter> beforeFilter) {
|
|
||||||
Integer position = getOrder(beforeFilter);
|
|
||||||
Assert.notNull(position, () -> "Cannot register after unregistered Filter " + beforeFilter);
|
|
||||||
put(filter, position - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void put(Class<? extends Filter> filter, int position) {
|
private void put(Class<? extends Filter> filter, int position) {
|
||||||
String className = filter.getName();
|
String className = filter.getName();
|
||||||
this.filterToOrder.put(className, position);
|
this.filterToOrder.put(className, position);
|
||||||
@ -181,7 +125,7 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
|
|||||||
* @param clazz the {@link Filter} class to determine the sort order
|
* @param clazz the {@link Filter} class to determine the sort order
|
||||||
* @return the sort order or null if not defined
|
* @return the sort order or null if not defined
|
||||||
*/
|
*/
|
||||||
private Integer getOrder(Class<?> clazz) {
|
Integer getOrder(Class<?> clazz) {
|
||||||
while (clazz != null) {
|
while (clazz != null) {
|
||||||
Integer result = this.filterToOrder.get(clazz.getName());
|
Integer result = this.filterToOrder.get(clazz.getName());
|
||||||
if (result != null) {
|
if (result != null) {
|
@ -16,14 +16,21 @@
|
|||||||
|
|
||||||
package org.springframework.security.config.annotation.web.builders;
|
package org.springframework.security.config.annotation.web.builders;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import javax.servlet.Filter;
|
import javax.servlet.Filter;
|
||||||
|
import javax.servlet.FilterChain;
|
||||||
|
import javax.servlet.ServletException;
|
||||||
|
import javax.servlet.ServletRequest;
|
||||||
|
import javax.servlet.ServletResponse;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
|
||||||
import org.springframework.context.ApplicationContext;
|
import org.springframework.context.ApplicationContext;
|
||||||
|
import org.springframework.core.OrderComparator;
|
||||||
|
import org.springframework.core.Ordered;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.security.authentication.AuthenticationManager;
|
import org.springframework.security.authentication.AuthenticationManager;
|
||||||
import org.springframework.security.authentication.AuthenticationProvider;
|
import org.springframework.security.authentication.AuthenticationProvider;
|
||||||
@ -129,11 +136,11 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul
|
|||||||
|
|
||||||
private final RequestMatcherConfigurer requestMatcherConfigurer;
|
private final RequestMatcherConfigurer requestMatcherConfigurer;
|
||||||
|
|
||||||
private List<Filter> filters = new ArrayList<>();
|
private List<OrderedFilter> filters = new ArrayList<>();
|
||||||
|
|
||||||
private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
|
private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
|
||||||
|
|
||||||
private FilterComparator comparator = new FilterComparator();
|
private FilterOrderRegistration filterOrders = new FilterOrderRegistration();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new instance
|
* Creates a new instance
|
||||||
@ -2609,8 +2616,12 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DefaultSecurityFilterChain performBuild() {
|
protected DefaultSecurityFilterChain performBuild() {
|
||||||
this.filters.sort(this.comparator);
|
this.filters.sort(OrderComparator.INSTANCE);
|
||||||
return new DefaultSecurityFilterChain(this.requestMatcher, this.filters);
|
List<Filter> sortedFilters = new ArrayList<>(this.filters.size());
|
||||||
|
for (Filter filter : this.filters) {
|
||||||
|
sortedFilters.add(((OrderedFilter) filter).filter);
|
||||||
|
}
|
||||||
|
return new DefaultSecurityFilterChain(this.requestMatcher, sortedFilters);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -2631,24 +2642,28 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public HttpSecurity addFilterAfter(Filter filter, Class<? extends Filter> afterFilter) {
|
public HttpSecurity addFilterAfter(Filter filter, Class<? extends Filter> afterFilter) {
|
||||||
this.comparator.registerAfter(filter.getClass(), afterFilter);
|
return addFilterAtOffsetOf(filter, 1, afterFilter);
|
||||||
return addFilter(filter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public HttpSecurity addFilterBefore(Filter filter, Class<? extends Filter> beforeFilter) {
|
public HttpSecurity addFilterBefore(Filter filter, Class<? extends Filter> beforeFilter) {
|
||||||
this.comparator.registerBefore(filter.getClass(), beforeFilter);
|
return addFilterAtOffsetOf(filter, -1, beforeFilter);
|
||||||
return addFilter(filter);
|
}
|
||||||
|
|
||||||
|
private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
|
||||||
|
int order = this.filterOrders.getOrder(registeredFilter) + offset;
|
||||||
|
this.filters.add(new OrderedFilter(filter, order));
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public HttpSecurity addFilter(Filter filter) {
|
public HttpSecurity addFilter(Filter filter) {
|
||||||
Class<? extends Filter> filterClass = filter.getClass();
|
Integer order = this.filterOrders.getOrder(filter.getClass());
|
||||||
if (!this.comparator.isRegistered(filterClass)) {
|
if (order == null) {
|
||||||
throw new IllegalArgumentException("The Filter class " + filterClass.getName()
|
throw new IllegalArgumentException("The Filter class " + filter.getClass().getName()
|
||||||
+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
|
+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
|
||||||
}
|
}
|
||||||
this.filters.add(filter);
|
this.filters.add(new OrderedFilter(filter, order));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2671,8 +2686,7 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul
|
|||||||
* @return the {@link HttpSecurity} for further customizations
|
* @return the {@link HttpSecurity} for further customizations
|
||||||
*/
|
*/
|
||||||
public HttpSecurity addFilterAt(Filter filter, Class<? extends Filter> atFilter) {
|
public HttpSecurity addFilterAt(Filter filter, Class<? extends Filter> atFilter) {
|
||||||
this.comparator.registerAt(filter.getClass(), atFilter);
|
return addFilterAtOffsetOf(filter, 0, atFilter);
|
||||||
return addFilter(filter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -3060,4 +3074,37 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Filter that implements Ordered to be sorted. After sorting occurs, the original
|
||||||
|
* filter is what is used by FilterChainProxy
|
||||||
|
*/
|
||||||
|
private static final class OrderedFilter implements Ordered, Filter {
|
||||||
|
|
||||||
|
private final Filter filter;
|
||||||
|
|
||||||
|
private final int order;
|
||||||
|
|
||||||
|
private OrderedFilter(Filter filter, int order) {
|
||||||
|
this.filter = filter;
|
||||||
|
this.order = order;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
|
||||||
|
throws IOException, ServletException {
|
||||||
|
this.filter.doFilter(servletRequest, servletResponse, filterChain);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOrder() {
|
||||||
|
return this.order;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "OrderedFilter{" + "filter=" + this.filter + ", order=" + this.order + '}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,132 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2002-2020 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.config.annotation.web.builders;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import javax.servlet.Filter;
|
||||||
|
import javax.servlet.FilterChain;
|
||||||
|
import javax.servlet.ServletException;
|
||||||
|
import javax.servlet.ServletRequest;
|
||||||
|
import javax.servlet.ServletResponse;
|
||||||
|
|
||||||
|
import org.assertj.core.api.ListAssert;
|
||||||
|
import org.junit.Rule;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
||||||
|
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
|
||||||
|
import org.springframework.security.config.test.SpringTestRule;
|
||||||
|
import org.springframework.security.web.FilterChainProxy;
|
||||||
|
import org.springframework.security.web.access.ExceptionTranslationFilter;
|
||||||
|
import org.springframework.security.web.access.channel.ChannelProcessingFilter;
|
||||||
|
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
|
||||||
|
import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
public class HttpSecurityAddFilterTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public final SpringTestRule spring = new SpringTestRule();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addFilterAfterWhenSameFilterDifferentPlacesThenOrderCorrect() {
|
||||||
|
this.spring.register(MyFilterMultipleAfterConfig.class).autowire();
|
||||||
|
|
||||||
|
assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
|
||||||
|
ExceptionTranslationFilter.class, MyFilter.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addFilterBeforeWhenSameFilterDifferentPlacesThenOrderCorrect() {
|
||||||
|
this.spring.register(MyFilterMultipleBeforeConfig.class).autowire();
|
||||||
|
|
||||||
|
assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
|
||||||
|
ExceptionTranslationFilter.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() {
|
||||||
|
this.spring.register(MyFilterMultipleAtConfig.class).autowire();
|
||||||
|
|
||||||
|
assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
|
||||||
|
ExceptionTranslationFilter.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ListAssert<Class<?>> assertThatFilters() {
|
||||||
|
FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
|
||||||
|
List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
return assertThat(filters);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class MyFilter implements Filter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
|
||||||
|
throws IOException, ServletException {
|
||||||
|
filterChain.doFilter(servletRequest, servletResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@EnableWebSecurity
|
||||||
|
static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void configure(HttpSecurity http) throws Exception {
|
||||||
|
// @formatter:off
|
||||||
|
http
|
||||||
|
.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
|
||||||
|
.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class);
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@EnableWebSecurity
|
||||||
|
static class MyFilterMultipleBeforeConfig extends WebSecurityConfigurerAdapter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void configure(HttpSecurity http) throws Exception {
|
||||||
|
// @formatter:off
|
||||||
|
http
|
||||||
|
.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
|
||||||
|
.addFilterBefore(new MyFilter(), ExceptionTranslationFilter.class);
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@EnableWebSecurity
|
||||||
|
static class MyFilterMultipleAtConfig extends WebSecurityConfigurerAdapter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void configure(HttpSecurity http) throws Exception {
|
||||||
|
// @formatter:off
|
||||||
|
http
|
||||||
|
.addFilterAt(new MyFilter(), ChannelProcessingFilter.class)
|
||||||
|
.addFilterAt(new MyFilter(), UsernamePasswordAuthenticationFilter.class);
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user