Refine requestMatcher Validation Rules

Closes gh-14078
This commit is contained in:
Josh Cummings 2023-10-11 14:01:36 -06:00
parent 3f64c6d745
commit ffd12ee3b9
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
4 changed files with 292 additions and 17 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import java.util.Map;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletRegistration; import javax.servlet.ServletRegistration;
import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
@ -321,12 +322,31 @@ public abstract class AbstractRequestMatcherRegistry<C> {
if (!hasDispatcherServlet(registrations)) { if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
} }
if (registrations.size() > 1) { ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
if (dispatcherServlet != null) {
if (registrations.size() == 1) {
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
}
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
if (dispatcherServlet != null) {
String mapping = dispatcherServlet.getMappings().iterator().next();
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
for (MvcRequestMatcher matcher : matchers) {
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
String errorMessage = computeErrorMessage(registrations.values()); String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage); throw new IllegalArgumentException(errorMessage);
} }
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
}
private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) { private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
Map<String, ServletRegistration> mappable = new LinkedHashMap<>(); Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
@ -343,21 +363,65 @@ public abstract class AbstractRequestMatcherRegistry<C> {
if (registrations == null) { if (registrations == null) {
return false; return false;
} }
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
for (ServletRegistration registration : registrations.values()) { for (ServletRegistration registration : registrations.values()) {
try { if (isDispatcherServlet(registration)) {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true; return true;
} }
} }
return false;
}
private ServletRegistration requireOneRootDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration rootDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
continue;
}
if (registration.getMappings().size() > 1) {
return null;
}
if (!"/".equals(registration.getMappings().iterator().next())) {
return null;
}
rootDispatcherServlet = registration;
}
return rootDispatcherServlet;
}
private ServletRegistration requireOnlyPathMappedDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration pathDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
return null;
}
if (registration.getMappings().size() > 1) {
return null;
}
String mapping = registration.getMappings().iterator().next();
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
return null;
}
if (pathDispatcherServlet != null) {
return null;
}
pathDispatcherServlet = registration;
}
return pathDispatcherServlet;
}
private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) { catch (ClassNotFoundException ex) {
return false; return false;
} }
} }
return false;
}
private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) { private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. " String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
@ -498,4 +562,55 @@ public abstract class AbstractRequestMatcherRegistry<C> {
} }
static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
private final AntPathRequestMatcher ant;
private final MvcRequestMatcher mvc;
private final ServletContext servletContext;
DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
ServletContext servletContext) {
this.ant = ant;
this.mvc = mvc;
this.servletContext = servletContext;
}
@Override
public boolean matches(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matches(request);
}
return this.ant.matches(request);
}
@Override
public MatchResult matcher(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matcher(request);
}
return this.ant.matcher(request);
}
private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}
}
} }

View File

@ -55,6 +55,11 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
return this.registrations; return this.registrations;
} }
@Override
public ServletRegistration getServletRegistration(String servletName) {
return this.registrations.get(servletName);
}
private static class MockServletRegistration implements ServletRegistration.Dynamic { private static class MockServletRegistration implements ServletRegistration.Dynamic {
private final String name; private final String name;

View File

@ -0,0 +1,46 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.MappingMatch;
import org.springframework.mock.web.MockHttpServletMapping;
public final class TestMockHttpServletMappings {
private TestMockHttpServletMappings() {
}
public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
String uri = request.getRequestURI();
String matchValue = uri.substring(0, uri.lastIndexOf(extension));
return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
}
public static MockHttpServletMapping path(HttpServletRequest request, String path) {
String uri = request.getRequestURI();
String matchValue = uri.substring(path.length());
return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
}
public static MockHttpServletMapping defaultMapping() {
return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,6 +22,7 @@ import java.util.List;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.Servlet; import javax.servlet.Servlet;
import javax.servlet.http.HttpServletMapping;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -29,8 +30,11 @@ import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.config.MockServletContext; import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.TestMockHttpServletMappings;
import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher;
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
@ -43,6 +47,9 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/** /**
* Tests for {@link AbstractRequestMatcherRegistry}. * Tests for {@link AbstractRequestMatcherRegistry}.
@ -197,6 +204,8 @@ public class AbstractRequestMatcherRegistryTests {
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() { public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
MockServletContext servletContext = new MockServletContext(); MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext); given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**"); List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty(); assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1); assertThat(requestMatchers).hasSize(1);
@ -214,7 +223,26 @@ public class AbstractRequestMatcherRegistryTests {
MockServletContext servletContext = new MockServletContext(); MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext); given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**"); servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}
@Test
public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}
@Test
public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
servletContext.addServlet("default", Servlet.class).addMapping("/");
assertThatExceptionOfType(IllegalArgumentException.class) assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
} }
@ -231,6 +259,87 @@ public class AbstractRequestMatcherRegistryTests {
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class); assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
} }
@Test
public void requestMatchersWhenOnlyDispatcherServletThenAllows() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
}
@Test
public void requestMatchersWhenImplicitServletsThenAllows() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("defaultServlet", Servlet.class);
servletContext.addServlet("jspServlet", Servlet.class).addMapping("*.jsp", "*.jspx");
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
}
@Test
public void requestMatchersWhenPathBasedNonDispatcherServletThenAllows() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/services/*");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint") {
@Override
public HttpServletMapping getHttpServletMapping() {
return TestMockHttpServletMappings.defaultMapping();
}
};
assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
request = new MockHttpServletRequest("GET", "/services/endpoint") {
@Override
public HttpServletMapping getHttpServletMapping() {
return TestMockHttpServletMappings.path(this, "/services");
}
};
request.setServletPath("/services");
request.setPathInfo("/endpoint");
assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
}
@Test
public void matchesWhenDispatcherServletThenMvc() {
MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
MvcRequestMatcher mvc = mock(MvcRequestMatcher.class);
AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class);
DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant,
mvc, servletContext);
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint") {
@Override
public HttpServletMapping getHttpServletMapping() {
return TestMockHttpServletMappings.defaultMapping();
}
};
assertThat(requestMatcher.matches(request)).isFalse();
verify(mvc).matches(request);
verifyNoInteractions(ant);
request = new MockHttpServletRequest("GET", "/services/endpoint") {
@Override
public HttpServletMapping getHttpServletMapping() {
return TestMockHttpServletMappings.path(this, "/services");
}
};
assertThat(requestMatcher.matches(request)).isFalse();
verify(ant).matches(request);
verifyNoMoreInteractions(mvc);
}
private void mockMvcIntrospector(boolean isPresent) { private void mockMvcIntrospector(boolean isPresent) {
ApplicationContext context = this.matcherRegistry.getApplicationContext(); ApplicationContext context = this.matcherRegistry.getApplicationContext();
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);