Ignore Unmappable Servlets

Closes gh-13666
This commit is contained in:
Josh Cummings 2023-08-18 15:11:33 -06:00
parent 7200f76ac1
commit ed96e2cddf
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
3 changed files with 34 additions and 7 deletions

View File

@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -312,8 +313,8 @@ public abstract class AbstractRequestMatcherRegistry<C> {
if (servletContext == null) { if (servletContext == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
} }
Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations(); Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
if (registrations == null) { if (registrations.isEmpty()) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
} }
if (!hasDispatcherServlet(registrations)) { if (!hasDispatcherServlet(registrations)) {
@ -324,6 +325,16 @@ public abstract class AbstractRequestMatcherRegistry<C> {
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
} }
private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations().entrySet()) {
if (!entry.getValue().getMappings().isEmpty()) {
mappable.put(entry.getKey(), entry.getValue());
}
}
return mappable;
}
private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) { private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
if (registrations == null) { if (registrations == null) {
return false; return false;

View File

@ -16,8 +16,10 @@
package org.springframework.security.config; package org.springframework.security.config;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -35,7 +37,7 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
public static MockServletContext mvc() { public static MockServletContext mvc() {
MockServletContext servletContext = new MockServletContext(); MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
return servletContext; return servletContext;
} }
@ -59,6 +61,8 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
private final Class<?> clazz; private final Class<?> clazz;
private final Set<String> mappings = new LinkedHashSet<>();
MockServletRegistration(String name, Class<?> clazz) { MockServletRegistration(String name, Class<?> clazz) {
this.name = name; this.name = name;
this.clazz = clazz; this.clazz = clazz;
@ -91,12 +95,13 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
@Override @Override
public Set<String> addMapping(String... urlPatterns) { public Set<String> addMapping(String... urlPatterns) {
return null; this.mappings.addAll(Arrays.asList(urlPatterns));
return this.mappings;
} }
@Override @Override
public Collection<String> getMappings() { public Collection<String> getMappings() {
return null; return this.mappings;
} }
@Override @Override

View File

@ -211,12 +211,23 @@ public class AbstractRequestMatcherRegistryTests {
public void requestMatchersWhenAmbiguousServletsThenException() { public void requestMatchersWhenAmbiguousServletsThenException() {
MockServletContext servletContext = new MockServletContext(); MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext); given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class); servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
assertThatExceptionOfType(IllegalArgumentException.class) assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
} }
@Test
public void requestMatchersWhenUnmappableServletsThenSkips() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class);
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
}
private void mockMvcIntrospector(boolean isPresent) { private void mockMvcIntrospector(boolean isPresent) {
ApplicationContext context = this.matcherRegistry.getApplicationContext(); ApplicationContext context = this.matcherRegistry.getApplicationContext();
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);