From 8827b2e5646cd0e0b0e81248c282df3093d8443c Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:58:18 -0700 Subject: [PATCH] Polish Using Request ServletContext Issue gh-14418 --- .../web/AbstractRequestMatcherRegistry.java | 16 ++++------------ .../web/AbstractRequestMatcherRegistryTests.java | 12 +++++------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index e29a4d599a..4f849f86fb 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -447,18 +447,12 @@ public abstract class AbstractRequestMatcherRegistry { static class DispatcherServletRequestMatcher implements RequestMatcher { - private final ServletContext servletContext; - - DispatcherServletRequestMatcher(ServletContext servletContext) { - this.servletContext = servletContext; - } - @Override public boolean matches(HttpServletRequest request) { String name = request.getHttpServletMapping().getServletName(); - ServletRegistration registration = this.servletContext.getServletRegistration(name); + ServletRegistration registration = request.getServletContext().getServletRegistration(name); Assert.notNull(registration, - () -> computeErrorMessage(this.servletContext.getServletRegistrations().values())); + () -> computeErrorMessage(request.getServletContext().getServletRegistrations().values())); try { Class clazz = Class.forName(registration.getClassName()); return DispatcherServlet.class.isAssignableFrom(clazz); @@ -478,10 +472,8 @@ public abstract class AbstractRequestMatcherRegistry { private final RequestMatcher dispatcherServlet; - DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc, - ServletContext servletContext) { - this(ant, mvc, new OrRequestMatcher(new MockMvcRequestMatcher(), - new DispatcherServletRequestMatcher(servletContext))); + DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc) { + this(ant, mvc, new OrRequestMatcher(new MockMvcRequestMatcher(), new DispatcherServletRequestMatcher())); } DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc, diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index 8561390515..1fb6e580b1 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -318,7 +318,7 @@ public class AbstractRequestMatcherRegistryTests { List requestMatchers = this.matcherRegistry.requestMatchers("/services/*"); assertThat(requestMatchers).hasSize(1); assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint"); + MockHttpServletRequest request = new MockHttpServletRequest(servletContext, "GET", "/services/endpoint"); request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping()); assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue(); request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services")); @@ -334,9 +334,8 @@ public class AbstractRequestMatcherRegistryTests { servletContext.addServlet("path", Servlet.class).addMapping("/services/*"); MvcRequestMatcher mvc = mock(MvcRequestMatcher.class); AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class); - DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant, - mvc, servletContext); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint"); + RequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant, mvc); + MockHttpServletRequest request = new MockHttpServletRequest(servletContext, "GET", "/services/endpoint"); request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping()); assertThat(requestMatcher.matches(request)).isFalse(); verify(mvc).matches(request); @@ -354,9 +353,8 @@ public class AbstractRequestMatcherRegistryTests { servletContext.addServlet("path", Servlet.class).addMapping("/services/*"); MvcRequestMatcher mvc = mock(MvcRequestMatcher.class); AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class); - DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant, - mvc, servletContext); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint"); + RequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant, mvc); + MockHttpServletRequest request = new MockHttpServletRequest(servletContext, "GET", "/services/endpoint"); assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> requestMatcher.matcher(request)); }