diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index d94e9d9083..e29a4d599a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -40,6 +40,7 @@ import org.springframework.core.ResolvableType; import org.springframework.http.HttpMethod; import org.springframework.lang.Nullable; import org.springframework.security.config.ObjectPostProcessor; +import org.springframework.security.config.annotation.web.ServletRegistrationsSupport.RegistrationMapping; import org.springframework.security.config.annotation.web.configurers.AbstractConfigAttributeRequestMatcherRegistry; import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -235,103 +236,31 @@ public abstract class AbstractRequestMatcherRegistry { } private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) { - Map registrations = mappableServletRegistrations(servletContext); - if (registrations.isEmpty()) { + ServletRegistrationsSupport registrations = new ServletRegistrationsSupport(servletContext); + Collection mappings = registrations.mappings(); + if (mappings.isEmpty()) { return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher()); } - if (!hasDispatcherServlet(registrations)) { + Collection dispatcherServletMappings = registrations.dispatcherServletMappings(); + if (dispatcherServletMappings.isEmpty()) { return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher()); } - ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); - if (dispatcherServlet != null) { - if (registrations.size() == 1) { + if (dispatcherServletMappings.size() > 1) { + String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values()); + throw new IllegalArgumentException(errorMessage); + } + RegistrationMapping dispatcherServlet = dispatcherServletMappings.iterator().next(); + if (mappings.size() > 1 && !dispatcherServlet.isDefault()) { + String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values()); + throw new IllegalArgumentException(errorMessage); + } + if (dispatcherServlet.isDefault()) { + if (mappings.size() == 1) { return mvc; } - return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext); - } - dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations); - if (dispatcherServlet != null) { - String mapping = dispatcherServlet.getMappings().iterator().next(); - mvc.setServletPath(mapping.substring(0, mapping.length() - 2)); - return mvc; - } - String errorMessage = computeErrorMessage(registrations.values()); - throw new IllegalArgumentException(errorMessage); - } - - private Map mappableServletRegistrations(ServletContext servletContext) { - Map mappable = new LinkedHashMap<>(); - for (Map.Entry entry : servletContext.getServletRegistrations() - .entrySet()) { - if (!entry.getValue().getMappings().isEmpty()) { - mappable.put(entry.getKey(), entry.getValue()); - } - } - return mappable; - } - - private boolean hasDispatcherServlet(Map registrations) { - if (registrations == null) { - return false; - } - for (ServletRegistration registration : registrations.values()) { - if (isDispatcherServlet(registration)) { - return true; - } - } - return false; - } - - private ServletRegistration requireOneRootDispatcherServlet( - Map registrations) { - ServletRegistration rootDispatcherServlet = null; - for (ServletRegistration registration : registrations.values()) { - if (!isDispatcherServlet(registration)) { - continue; - } - if (registration.getMappings().size() > 1) { - return null; - } - if (!"/".equals(registration.getMappings().iterator().next())) { - return null; - } - rootDispatcherServlet = registration; - } - return rootDispatcherServlet; - } - - private ServletRegistration requireOnlyPathMappedDispatcherServlet( - Map registrations) { - ServletRegistration pathDispatcherServlet = null; - for (ServletRegistration registration : registrations.values()) { - if (!isDispatcherServlet(registration)) { - return null; - } - if (registration.getMappings().size() > 1) { - return null; - } - String mapping = registration.getMappings().iterator().next(); - if (!mapping.startsWith("/") || !mapping.endsWith("/*")) { - return null; - } - if (pathDispatcherServlet != null) { - return null; - } - pathDispatcherServlet = registration; - } - return pathDispatcherServlet; - } - - private boolean isDispatcherServlet(ServletRegistration registration) { - Class dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", - null); - try { - Class clazz = Class.forName(registration.getClassName()); - return dispatcherServlet.isAssignableFrom(clazz); - } - catch (ClassNotFoundException ex) { - return false; + return new DispatcherServletDelegatingRequestMatcher(ant, mvc); } + return mvc; } private static String computeErrorMessage(Collection registrations) { diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/ServletRegistrationsSupport.java b/config/src/main/java/org/springframework/security/config/annotation/web/ServletRegistrationsSupport.java new file mode 100644 index 0000000000..e84b8455f1 --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/annotation/web/ServletRegistrationsSupport.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map; + +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletRegistration; + +import org.springframework.util.ClassUtils; + +class ServletRegistrationsSupport { + + private final Collection registrations; + + ServletRegistrationsSupport(ServletContext servletContext) { + Map registrations = servletContext.getServletRegistrations(); + Collection mappings = new ArrayList<>(); + for (Map.Entry entry : registrations.entrySet()) { + if (!entry.getValue().getMappings().isEmpty()) { + for (String mapping : entry.getValue().getMappings()) { + mappings.add(new RegistrationMapping(entry.getValue(), mapping)); + } + } + } + this.registrations = mappings; + } + + Collection dispatcherServletMappings() { + Collection mappings = new ArrayList<>(); + for (RegistrationMapping registration : this.registrations) { + if (registration.isDispatcherServlet()) { + mappings.add(registration); + } + } + return mappings; + } + + Collection mappings() { + return this.registrations; + } + + record RegistrationMapping(ServletRegistration registration, String mapping) { + boolean isDispatcherServlet() { + Class dispatcherServlet = ClassUtils + .resolveClassName("org.springframework.web.servlet.DispatcherServlet", null); + try { + Class clazz = Class.forName(this.registration.getClassName()); + return dispatcherServlet.isAssignableFrom(clazz); + } + catch (ClassNotFoundException ex) { + return false; + } + } + + boolean isDefault() { + return "/".equals(this.mapping); + } + } + +}