Merge branch '6.1.x'

This commit is contained in:
Josh Cummings 2023-11-17 12:01:57 -07:00
commit 5958828113
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
2 changed files with 95 additions and 15 deletions

View File

@ -22,6 +22,8 @@ import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletContext;
@ -42,6 +44,7 @@ import org.springframework.security.web.util.matcher.RegexRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.function.SingletonSupplier;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
@ -197,34 +200,51 @@ public abstract class AbstractRequestMatcherRegistry<C> {
if (servletContext == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
boolean isProgrammaticApiAvailable = isProgrammaticApiAvailable(servletContext);
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
if (isProgrammaticApiAvailable) {
matchers.add(resolve(ant, mvc, servletContext));
}
else {
matchers.add(new DeferredRequestMatcher(() -> resolve(ant, mvc, servletContext), mvc, ant));
}
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
private static boolean isProgrammaticApiAvailable(ServletContext servletContext) {
try {
servletContext.getServletRegistrations();
return true;
}
catch (UnsupportedOperationException ex) {
return false;
}
}
private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) {
Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
if (registrations.isEmpty()) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
return ant;
}
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
return ant;
}
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
if (dispatcherServlet != null) {
if (registrations.size() == 1) {
return requestMatchers(createMvcMatchers(method, patterns).toArray(RequestMatcher[]::new));
return mvc;
}
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext);
}
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
if (dispatcherServlet != null) {
String mapping = dispatcherServlet.getMappings().iterator().next();
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
for (MvcRequestMatcher matcher : matchers) {
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
mvc.setServletPath(mapping.substring(0, mapping.length() - 2));
return mvc;
}
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
@ -444,6 +464,38 @@ public abstract class AbstractRequestMatcherRegistry<C> {
}
static class DeferredRequestMatcher implements RequestMatcher {
final Supplier<RequestMatcher> requestMatcher;
final AtomicReference<String> description = new AtomicReference<>();
DeferredRequestMatcher(Supplier<RequestMatcher> resolver, RequestMatcher... candidates) {
this.requestMatcher = SingletonSupplier.of(() -> {
RequestMatcher matcher = resolver.get();
this.description.set(matcher.toString());
return matcher;
});
this.description.set("Deferred " + candidates);
}
@Override
public boolean matches(HttpServletRequest request) {
return this.requestMatcher.get().matches(request);
}
@Override
public MatchResult matcher(HttpServletRequest request) {
return this.requestMatcher.get().matcher(request);
}
@Override
public String toString() {
return this.description.get();
}
}
static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
private final AntPathRequestMatcher ant;
@ -493,6 +545,11 @@ public abstract class AbstractRequestMatcherRegistry<C> {
}
}
@Override
public String toString() {
return "DispatcherServletDelegating [" + "ant = " + this.ant + ", mvc = " + this.mvc + "]";
}
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.security.config.annotation.web;
import java.util.ArrayList;
import java.util.List;
import jakarta.servlet.DispatcherType;
@ -163,6 +164,7 @@ public class AbstractRequestMatcherRegistryTests {
@Test
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
@ -175,6 +177,7 @@ public class AbstractRequestMatcherRegistryTests {
@Test
public void requestMatchersWhenAmbiguousServletsThenException() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
@ -185,6 +188,7 @@ public class AbstractRequestMatcherRegistryTests {
@Test
public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*");
@ -194,6 +198,7 @@ public class AbstractRequestMatcherRegistryTests {
@Test
public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
@ -282,11 +287,29 @@ public class AbstractRequestMatcherRegistryTests {
private static class TestRequestMatcherRegistry extends AbstractRequestMatcherRegistry<List<RequestMatcher>> {
@Override
public List<RequestMatcher> requestMatchers(RequestMatcher... requestMatchers) {
return unwrap(super.requestMatchers(requestMatchers));
}
@Override
protected List<RequestMatcher> chainRequestMatchers(List<RequestMatcher> requestMatchers) {
return requestMatchers;
}
private static List<RequestMatcher> unwrap(List<RequestMatcher> wrappedMatchers) {
List<RequestMatcher> requestMatchers = new ArrayList<>();
for (RequestMatcher requestMatcher : wrappedMatchers) {
if (requestMatcher instanceof AbstractRequestMatcherRegistry.DeferredRequestMatcher) {
requestMatchers.add(((DeferredRequestMatcher) requestMatcher).requestMatcher.get());
}
else {
requestMatchers.add(requestMatcher);
}
}
return requestMatchers;
}
}
}