Improve RequestMatcher Validation

Closes gh-13551
This commit is contained in:
Josh Cummings 2023-07-05 11:45:20 -06:00
parent a939f17890
commit df239b6448
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
3 changed files with 148 additions and 17 deletions

View File

@ -19,8 +19,11 @@ package org.springframework.security.config.annotation.web;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.ServletContext;
import javax.servlet.ServletRegistration;
import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
@ -36,6 +39,7 @@ import org.springframework.security.web.util.matcher.RegexRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector; import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
/** /**
@ -297,14 +301,47 @@ public abstract class AbstractRequestMatcherRegistry<C> {
* @since 5.8 * @since 5.8
*/ */
public C requestMatchers(HttpMethod method, String... patterns) { public C requestMatchers(HttpMethod method, String... patterns) {
List<RequestMatcher> matchers = new ArrayList<>(); if (!mvcPresent) {
if (mvcPresent) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
matchers.addAll(createMvcMatchers(method, patterns));
} }
else { if (!(this.context instanceof WebApplicationContext)) {
matchers.addAll(RequestMatchers.antMatchers(method, patterns)); return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
} }
return requestMatchers(matchers.toArray(new RequestMatcher[0])); WebApplicationContext context = (WebApplicationContext) this.context;
ServletContext servletContext = context.getServletContext();
if (servletContext == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
if (registrations == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
Assert.isTrue(registrations.size() == 1,
"This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher).");
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
}
private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
if (registrations == null) {
return false;
}
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
for (ServletRegistration registration : registrations.values()) {
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
}
}
catch (ClassNotFoundException ex) {
return false;
}
}
return false;
} }
/** /**
@ -380,12 +417,7 @@ public abstract class AbstractRequestMatcherRegistry<C> {
* @return a {@link List} of {@link AntPathRequestMatcher} instances * @return a {@link List} of {@link AntPathRequestMatcher} instances
*/ */
static List<RequestMatcher> antMatchers(HttpMethod httpMethod, String... antPatterns) { static List<RequestMatcher> antMatchers(HttpMethod httpMethod, String... antPatterns) {
String method = (httpMethod != null) ? httpMethod.toString() : null; return Arrays.asList(antMatchersAsArray(httpMethod, antPatterns));
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : antPatterns) {
matchers.add(new AntPathRequestMatcher(pattern, method));
}
return matchers;
} }
/** /**
@ -399,6 +431,15 @@ public abstract class AbstractRequestMatcherRegistry<C> {
return antMatchers(null, antPatterns); return antMatchers(null, antPatterns);
} }
static RequestMatcher[] antMatchersAsArray(HttpMethod httpMethod, String... antPatterns) {
String method = (httpMethod != null) ? httpMethod.toString() : null;
RequestMatcher[] matchers = new RequestMatcher[antPatterns.length];
for (int index = 0; index < antPatterns.length; index++) {
matchers[index] = new AntPathRequestMatcher(antPatterns[index], method);
}
return matchers;
}
/** /**
* Create a {@link List} of {@link RegexRequestMatcher} instances. * Create a {@link List} of {@link RegexRequestMatcher} instances.
* @param httpMethod the {@link HttpMethod} to use or {@code null} for any * @param httpMethod the {@link HttpMethod} to use or {@code null} for any

View File

@ -18,10 +18,16 @@ package org.springframework.security.config.annotation.web;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.Servlet;
import javax.servlet.ServletContext;
import javax.servlet.ServletRegistration;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -34,6 +40,8 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -56,12 +64,17 @@ public class AbstractRequestMatcherRegistryTests {
private TestRequestMatcherRegistry matcherRegistry; private TestRequestMatcherRegistry matcherRegistry;
private WebApplicationContext context;
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
this.matcherRegistry = new TestRequestMatcherRegistry(); this.matcherRegistry = new TestRequestMatcherRegistry();
ApplicationContext context = mock(ApplicationContext.class); this.context = mock(WebApplicationContext.class);
given(context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR); ServletContext servletContext = new MockServletContext();
this.matcherRegistry.setApplicationContext(context); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
given(this.context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR);
given(this.context.getServletContext()).willReturn(servletContext);
this.matcherRegistry.setApplicationContext(this.context);
} }
@Test @Test
@ -184,6 +197,32 @@ public class AbstractRequestMatcherRegistryTests {
"Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext"); "Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext");
} }
@Test
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
servletContext.addServlet("servletOne", Servlet.class);
servletContext.addServlet("servletTwo", Servlet.class);
requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
}
@Test
public void requestMatchersWhenAmbiguousServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
servletContext.addServlet("servletTwo", Servlet.class);
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}
private void mockMvcIntrospector(boolean isPresent) { private void mockMvcIntrospector(boolean isPresent) {
ApplicationContext context = this.matcherRegistry.getApplicationContext(); ApplicationContext context = this.matcherRegistry.getApplicationContext();
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);
@ -217,4 +256,25 @@ public class AbstractRequestMatcherRegistryTests {
} }
private static class MockServletContext extends org.springframework.mock.web.MockServletContext {
private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();
@NotNull
@Override
public ServletRegistration.Dynamic addServlet(@NotNull String servletName, Class<? extends Servlet> clazz) {
ServletRegistration.Dynamic dynamic = mock(ServletRegistration.Dynamic.class);
given(dynamic.getClassName()).willReturn(clazz.getName());
this.registrations.put(servletName, dynamic);
return dynamic;
}
@NotNull
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations() {
return this.registrations;
}
}
} }

View File

@ -18,9 +18,14 @@ package org.springframework.security.config.annotation.web.configurers;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.Servlet;
import javax.servlet.ServletRegistration;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -34,7 +39,6 @@ import org.springframework.core.annotation.Order;
import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry; import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@ -48,12 +52,15 @@ import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.config.annotation.PathMatchConfigurer; import org.springframework.web.servlet.config.annotation.PathMatchConfigurer;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector; import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.Customizer.withDefaults;
/** /**
@ -233,7 +240,9 @@ public class HttpSecuritySecurityMatchersTests {
public void loadConfig(Class<?>... configs) { public void loadConfig(Class<?>... configs) {
this.context = new AnnotationConfigWebApplicationContext(); this.context = new AnnotationConfigWebApplicationContext();
this.context.register(configs); this.context.register(configs);
this.context.setServletContext(new MockServletContext()); MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
this.context.setServletContext(servletContext);
this.context.refresh(); this.context.refresh();
this.context.getAutowireCapableBeanFactory().autowireBean(this); this.context.getAutowireCapableBeanFactory().autowireBean(this);
} }
@ -564,4 +573,25 @@ public class HttpSecuritySecurityMatchersTests {
} }
private static class MockServletContext extends org.springframework.mock.web.MockServletContext {
private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();
@NotNull
@Override
public ServletRegistration.Dynamic addServlet(@NotNull String servletName, Class<? extends Servlet> clazz) {
ServletRegistration.Dynamic dynamic = mock(ServletRegistration.Dynamic.class);
given(dynamic.getClassName()).willReturn(clazz.getName());
this.registrations.put(servletName, dynamic);
return dynamic;
}
@NotNull
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations() {
return this.registrations;
}
}
} }