From 50ad378a29be0e6734f440fdcb45fed18bfacaab Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Wed, 26 Mar 2025 11:12:37 -0600 Subject: [PATCH] Polish MockHttpServletRequest Usage This commit makes so that the requestURI is set to a value that makes sense with the other properties being mocked. Issue gh-16632 --- .../cas/web/CasAuthenticationFilterTests.java | 9 ++-- ...ml4AuthenticationRequestResolverTests.java | 2 +- ...ml5AuthenticationRequestResolverTests.java | 2 +- ...RelyingPartyRegistrationResolverTests.java | 3 +- .../service/web/Saml2MetadataFilterTests.java | 52 +++++++++++-------- .../Saml2WebSsoAuthenticationFilterTests.java | 5 ++ ...ctAuthenticationProcessingFilterTests.java | 2 +- .../logout/LogoutHandlerTests.java | 9 ++-- .../ott/GenerateOneTimeTokenFilterTests.java | 1 + ...eTokenSubmitPageGeneratingFilterTests.java | 4 +- 10 files changed, 52 insertions(+), 37 deletions(-) diff --git a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java index b4227bb9a0..74a2e2ea13 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java @@ -78,7 +78,7 @@ public class CasAuthenticationFilterTests { @Test public void testNormalOperation() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login/cas"); request.setServletPath("/login/cas"); request.addParameter("ticket", "ST-0-ER94xMJmn6pha35CQRoZ"); CasAuthenticationFilter filter = new CasAuthenticationFilter(); @@ -103,7 +103,7 @@ public class CasAuthenticationFilterTests { String url = "/login/cas"; CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", url); MockHttpServletResponse response = new MockHttpServletResponse(); request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); @@ -132,10 +132,11 @@ public class CasAuthenticationFilterTests { CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); filter.setServiceProperties(properties); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", url); MockHttpServletResponse response = new MockHttpServletResponse(); request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); + request = new MockHttpServletRequest("POST", "/other"); request.setServletPath("/other"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); request.setParameter(properties.getArtifactParameter(), "value"); @@ -170,7 +171,7 @@ public class CasAuthenticationFilterTests { given(manager.authenticate(any(Authentication.class))).willReturn(authentication); ServiceProperties serviceProperties = new ServiceProperties(); serviceProperties.setAuthenticateAllArtifacts(true); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/authenticate"); request.setParameter("ticket", "ST-1-123"); request.setServletPath("/authenticate"); MockHttpServletResponse response = new MockHttpServletResponse(); diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java index e834b52149..2716f0befd 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java @@ -102,7 +102,7 @@ public class OpenSaml4AuthenticationRequestResolverTests { } private MockHttpServletRequest givenRequest(String path) { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", path); request.setServletPath(path); return request; } diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java index 2284dfa0fd..bf5d059e1a 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java @@ -102,7 +102,7 @@ public class OpenSaml5AuthenticationRequestResolverTests { } private MockHttpServletRequest givenRequest(String path) { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", path); request.setServletPath(path); return request; } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java index 17e05d0f7b..aef38edede 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java @@ -43,7 +43,8 @@ public class DefaultRelyingPartyRegistrationResolverTests { @Test public void resolveWhenRequestContainsRegistrationIdThenResolves() { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", + "/some/path/" + this.registration.getRegistrationId()); request.setPathInfo("/some/path/" + this.registration.getRegistrationId()); RelyingPartyRegistration registration = this.resolver.convert(request); assertThat(registration).isNotNull(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java index 937ec7968f..471ca2a859 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java @@ -74,39 +74,39 @@ public class Saml2MetadataFilterTests { @Test public void doFilterWhenMatcherSucceedsThenResolverInvoked() throws Exception { - this.request.setPathInfo("/saml2/service-provider-metadata/registration-id"); - this.filter.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = uri("/saml2/service-provider-metadata/registration-id"); + this.filter.doFilter(request, this.response, this.chain); verifyNoInteractions(this.chain); verify(this.repository).findByRegistrationId("registration-id"); } @Test public void doFilterWhenMatcherFailsThenProcessesFilterChain() throws Exception { - this.request.setPathInfo("/saml2/authenticate/registration-id"); - this.filter.doFilter(this.request, this.response, this.chain); - verify(this.chain).doFilter(this.request, this.response); + MockHttpServletRequest request = uri("/saml2/authenticate/registration-id"); + this.filter.doFilter(request, this.response, this.chain); + verify(this.chain).doFilter(request, this.response); } @Test public void doFilterWhenNoRelyingPartyRegistrationThenUnauthorized() throws Exception { - this.request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration"); + MockHttpServletRequest request = uri("/saml2/service-provider-metadata/invalidRegistration"); given(this.repository.findByRegistrationId("invalidRegistration")).willReturn(null); - this.filter.doFilter(this.request, this.response, this.chain); + this.filter.doFilter(request, this.response, this.chain); verifyNoInteractions(this.chain); assertThat(this.response.getStatus()).isEqualTo(401); } @Test public void doFilterWhenRelyingPartyRegistrationFoundThenInvokesMetadataResolver() throws Exception { - this.request.setPathInfo("/saml2/service-provider-metadata/validRegistration"); + MockHttpServletRequest request = uri("/saml2/service-provider-metadata/validRegistration"); RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.noCredentials() .assertingPartyDetails((party) -> party .verificationX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) .build(); String generatedMetadata = "test"; given(this.resolver.resolve(validRegistration)).willReturn(generatedMetadata); - this.filter = new Saml2MetadataFilter((request, registrationId) -> validRegistration, this.resolver); - this.filter.doFilter(this.request, this.response, this.chain); + this.filter = new Saml2MetadataFilter((r, registrationId) -> validRegistration, this.resolver); + this.filter.doFilter(request, this.response, this.chain); verifyNoInteractions(this.chain); assertThat(this.response.getStatus()).isEqualTo(200); assertThat(this.response.getContentAsString()).isEqualTo(generatedMetadata); @@ -128,9 +128,9 @@ public class Saml2MetadataFilterTests { @Test public void doFilterWhenCustomRequestMatcherThenUses() throws Exception { - this.request.setPathInfo("/path"); + MockHttpServletRequest request = uri("/path"); this.filter.setRequestMatcher(new AntPathRequestMatcher("/path")); - this.filter.doFilter(this.request, this.response, this.chain); + this.filter.doFilter(request, this.response, this.chain); verifyNoInteractions(this.chain); verify(this.repository).findByRegistrationId("path"); } @@ -142,11 +142,11 @@ public class Saml2MetadataFilterTests { String fileName = testMetadataFilename.replace("{registrationId}", validRegistration.getRegistrationId()); String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8.name()); String generatedMetadata = "test"; - this.request.setPathInfo("/saml2/service-provider-metadata/registration-id"); + MockHttpServletRequest request = uri("/saml2/service-provider-metadata/registration-id"); given(this.resolver.resolve(validRegistration)).willReturn(generatedMetadata); - this.filter = new Saml2MetadataFilter((request, registrationId) -> validRegistration, this.resolver); + this.filter = new Saml2MetadataFilter((r, registrationId) -> validRegistration, this.resolver); this.filter.setMetadataFilename(testMetadataFilename); - this.filter.doFilter(this.request, this.response, this.chain); + this.filter.doFilter(request, this.response, this.chain); assertThat(this.response.getHeaderValue(HttpHeaders.CONTENT_DISPOSITION)).asString() .isEqualTo("attachment; filename=\"%s\"; filename*=UTF-8''%s", fileName, encodedFileName); } @@ -160,8 +160,8 @@ public class Saml2MetadataFilterTests { (id) -> this.repository.findByRegistrationId("registration-id")); this.filter = new Saml2MetadataFilter(resolver, this.resolver); this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); - this.request.setPathInfo("/metadata"); - this.filter.doFilter(this.request, this.response, new MockFilterChain()); + MockHttpServletRequest request = uri("/metadata"); + this.filter.doFilter(request, this.response, new MockFilterChain()); verify(this.repository).findByRegistrationId("registration-id"); } @@ -174,8 +174,8 @@ public class Saml2MetadataFilterTests { this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("registration-id"), this.resolver); this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata")); - this.request.setPathInfo("/metadata"); - this.filter.doFilter(this.request, this.response, new MockFilterChain()); + MockHttpServletRequest request = uri("/metadata"); + this.filter.doFilter(request, this.response, new MockFilterChain()); verify(this.repository).findByRegistrationId("registration-id"); } @@ -185,11 +185,11 @@ public class Saml2MetadataFilterTests { RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.full().build(); String testMetadataFilename = "test-{registrationId}-metadata.xml"; String generatedMetadata = "testäöü"; - this.request.setPathInfo("/saml2/service-provider-metadata/registration-id"); + MockHttpServletRequest request = uri("/saml2/service-provider-metadata/registration-id"); given(this.resolver.resolve(validRegistration)).willReturn(generatedMetadata); this.filter = new Saml2MetadataFilter((req, id) -> validRegistration, this.resolver); this.filter.setMetadataFilename(testMetadataFilename); - this.filter.doFilter(this.request, this.response, this.chain); + this.filter.doFilter(request, this.response, this.chain); assertThat(this.response.getCharacterEncoding()).isEqualTo(StandardCharsets.UTF_8.name()); assertThat(this.response.getContentAsString(StandardCharsets.UTF_8)).isEqualTo(generatedMetadata); assertThat(this.response.getContentLength()) @@ -218,9 +218,15 @@ public class Saml2MetadataFilterTests { public void constructorWhenRelyingPartyRegistrationRepositoryThenUses() throws Exception { RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); this.filter = new Saml2MetadataFilter(repository, this.resolver); - this.request.setPathInfo("/saml2/service-provider-metadata/one"); - this.filter.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = uri("/saml2/service-provider-metadata/one"); + this.filter.doFilter(request, this.response, this.chain); verify(repository).findByRegistrationId("one"); } + private MockHttpServletRequest uri(String uri) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); + request.setPathInfo(uri); + return request; + } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java index 7b3abddf95..98cf1765df 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java @@ -68,6 +68,7 @@ public class Saml2WebSsoAuthenticationFilterTests { @BeforeEach public void setup() { this.filter = new Saml2WebSsoAuthenticationFilter(this.repository); + this.request.setRequestURI("/login/saml2/sso/idp-registration-id"); this.request.setPathInfo("/login/saml2/sso/idp-registration-id"); this.request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "xml-data-goes-here"); } @@ -99,6 +100,7 @@ public class Saml2WebSsoAuthenticationFilterTests { @Test public void requiresAuthenticationWhenCustomProcessingUrlThenReturnsTrue() { this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}"); + this.request.setRequestURI("/some/other/path/idp-registration-id"); this.request.setPathInfo("/some/other/path/idp-registration-id"); this.request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "xml-data-goes-here"); assertThat(this.filter.requiresAuthentication(this.request, this.response)).isTrue(); @@ -108,6 +110,7 @@ public class Saml2WebSsoAuthenticationFilterTests { public void attemptAuthenticationWhenRegistrationIdDoesNotExistThenThrowsException() { given(this.repository.findByRegistrationId("non-existent-id")).willReturn(null); this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}"); + this.request.setRequestURI("/some/other/path/non-existent-id"); this.request.setPathInfo("/some/other/path/non-existent-id"); this.request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); assertThatExceptionOfType(Saml2AuthenticationException.class) @@ -123,6 +126,7 @@ public class Saml2WebSsoAuthenticationFilterTests { given(authenticationConverter.convert(this.request)).willReturn(TestSaml2AuthenticationTokens.token()); this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}"); this.filter.setAuthenticationManager((authentication) -> null); + this.request.setRequestURI("/some/other/path/idp-registration-id"); this.request.setPathInfo("/some/other/path/idp-registration-id"); this.filter.setAuthenticationRequestRepository(authenticationRequestRepository); this.filter.attemptAuthentication(this.request, this.response); @@ -201,6 +205,7 @@ public class Saml2WebSsoAuthenticationFilterTests { Saml2AuthenticationTokenConverter authenticationConverter = new Saml2AuthenticationTokenConverter(resolver); this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, loginProcessingUrl); this.filter.setAuthenticationManager(this.authenticationManager); + this.request.setRequestURI("/registration-id/login/saml2/sso"); this.request.setPathInfo("/registration-id/login/saml2/sso"); this.request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); this.filter.doFilter(this.request, this.response, new MockFilterChain()); diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java index 3d52928047..00e4de0614 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java @@ -100,7 +100,7 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testDefaultProcessesFilterUrlMatchesWithPathParameter() { - MockHttpServletRequest request = createMockAuthenticationRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login;jsessionid=I8MIONOSTHOR"); MockHttpServletResponse response = new MockHttpServletResponse(); MockAuthenticationFilter filter = new MockAuthenticationFilter(); filter.setFilterProcessesUrl("/login"); diff --git a/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java index e86c7223fd..6039fd27a8 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java @@ -39,9 +39,9 @@ public class LogoutHandlerTests { @Test public void testRequiresLogoutUrlWorksWithPathParams() { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/context/logout;someparam=blah"); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setRequestURI("/context/logout;someparam=blah?param=blah"); + request.setContextPath("/context"); request.setServletPath("/logout;someparam=blah"); request.setQueryString("otherparam=blah"); DefaultHttpFirewall fw = new DefaultHttpFirewall(); @@ -50,12 +50,11 @@ public class LogoutHandlerTests { @Test public void testRequiresLogoutUrlWorksWithQueryParams() { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/context/logout"); request.setContextPath("/context"); MockHttpServletResponse response = new MockHttpServletResponse(); request.setServletPath("/logout"); - request.setRequestURI("/context/logout?param=blah"); - request.setQueryString("otherparam=blah"); + request.setQueryString("param=blah"); assertThat(this.filter.requiresLogout(request, response)).isTrue(); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java index f3cdb2fd51..4c7085ab70 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java @@ -65,6 +65,7 @@ public class GenerateOneTimeTokenFilterTests { void setup() { this.request.setMethod("POST"); this.request.setServletPath("/ott/generate"); + this.request.setRequestURI("/ott/generate"); } @Test diff --git a/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java index 0aab14a8fc..ad38fa6f7c 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java @@ -37,7 +37,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { DefaultOneTimeTokenSubmitPageGeneratingFilter filter = new DefaultOneTimeTokenSubmitPageGeneratingFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/login/ott"); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -47,6 +47,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { void setup() { this.request.setMethod("GET"); this.request.setServletPath("/login/ott"); + this.request.setRequestURI("/login/ott"); } @Test @@ -80,6 +81,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { @Test void setContextThenGenerates() throws Exception { this.request.setContextPath("/context"); + this.request.setRequestURI("/context/login/ott"); this.filter.setLoginProcessingUrl("/login/another"); this.filter.doFilterInternal(this.request, this.response, this.filterChain); String response = this.response.getContentAsString();