From 531c5cafdc9a70bf63198b4bdadcd13561faf8c7 Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Mon, 30 Jun 2025 12:24:16 -0600 Subject: [PATCH] Standardize Mocked Request Paths Historically, Spring Security tests have set the servlet path to indicate the path of a MockHttpServletRequest. This was needed for AntPath and MvcRequestMatcher to correctly match the specified request path. This can leave MockHttpServletRequest in an inconsistent state since requestURI is null while servletPath has a value. For example, PathPatternRequestMatcher does not use the servlet path. For tests to continue working both before and after the migration from AntPath/MvcRequestMatcher to PathPatternRequestMatcher, the mock requests should have a consistent representation of path in getRequestURI and getServletPath. This commit updates classes to use TestMockHttpServletRequests, which ensures that the given path is applied to the servletPath and requestURI, while also overriding with contextPath, servletPath, and pathInfo when necessary. --- cas/spring-security-cas.gradle | 1 + .../cas/web/CasAuthenticationFilterTests.java | 36 ++-- .../config/FilterChainProxyConfigTests.java | 6 +- .../configurers/AuthorizeRequestsTests.java | 7 - .../configurers/HttpSecurityLogoutTests.java | 11 +- .../HttpSecurityRequestMatchersTests.java | 59 +++--- ...ttpSecuritySecurityMatchersNoMvcTests.java | 8 +- ...ionManagementConfigurerServlet31Tests.java | 23 +-- .../client/OAuth2LoginConfigurerTests.java | 52 ++---- ...reshedEventListenerConfigurationTests.java | 4 +- .../OidcUserRefreshedEventListenerTests.java | 4 +- .../saml2/Saml2LoginConfigurerTests.java | 5 +- .../saml2/Saml2LogoutConfigurerTests.java | 4 +- ...tadataSourceBeanDefinitionParserTests.java | 4 +- .../Saml2LogoutBeanDefinitionParserTests.java | 3 +- ...SessionManagementConfigServlet31Tests.java | 34 ++-- .../CustomHttpSecurityConfigurerTests.java | 10 +- ...egistrationsBeanDefinitionParserTests.java | 5 +- .../web/AuthorizeRequestsDslTests.kt | 31 +--- .../annotation/web/RequiresChannelDslTests.kt | 10 +- etc/checkstyle/checkstyle.xml | 1 + .../spring-security-itest-context.gradle | 2 +- ...amespaceWithMultipleInterceptorsTests.java | 9 +- .../spring-security-oauth2-client.gradle | 1 + ...uth2AuthorizationRequestResolverTests.java | 106 ++++------- ...uth2AuthorizationCodeGrantFilterTests.java | 10 +- ...thorizationRequestRedirectFilterTests.java | 37 ++-- .../OAuth2LoginAuthenticationFilterTests.java | 95 ++++------ ...ing-security-saml2-service-provider.gradle | 1 + ...aml4AuthenticationTokenConverterTests.java | 9 +- ...SamlAuthenticationTokenConverterTests.java | 9 +- ...ml4AuthenticationRequestResolverTests.java | 4 +- ...questValidatorParametersResolverTests.java | 9 +- ...questValidatorParametersResolverTests.java | 9 +- ...aml5AuthenticationTokenConverterTests.java | 9 +- ...ml5AuthenticationRequestResolverTests.java | 5 +- ...questValidatorParametersResolverTests.java | 9 +- ...tMatcherMetadataResponseResolverTests.java | 5 +- .../logout/Saml2LogoutRequestFilterTests.java | 34 ++-- .../Saml2LogoutResponseFilterTests.java | 27 ++- ...rtyInitiatedLogoutSuccessHandlerTests.java | 7 +- .../security/web/FilterChainProxyTests.java | 4 +- .../security/web/FilterInvocationTests.java | 36 ++-- .../RequestMatcherRedirectFilterTests.java | 8 +- .../ExceptionTranslationFilterTests.java | 44 ++--- .../channel/ChannelProcessingFilterTests.java | 10 +- .../InsecureChannelProcessorTests.java | 22 +-- .../channel/SecureChannelProcessorTests.java | 22 +-- ...leEvaluationContextPostProcessorTests.java | 4 +- ...InvocationSecurityMetadataSourceTests.java | 10 +- .../FilterSecurityInterceptorTests.java | 4 +- ...ctAuthenticationProcessingFilterTests.java | 36 ++-- ...LoginUrlAuthenticationEntryPointTests.java | 61 ++----- ...ingAuthenticationManagerResolverTests.java | 4 +- ...namePasswordAuthenticationFilterTests.java | 9 +- .../logout/LogoutHandlerTests.java | 16 +- .../ott/GenerateOneTimeTokenFilterTests.java | 8 +- ...eTokenSubmitPageGeneratingFilterTests.java | 14 +- .../www/BasicAuthenticationFilterTests.java | 56 ++---- .../www/DigestAuthenticationFilterTests.java | 4 +- .../security/web/debug/DebugFilterTests.java | 6 +- .../firewall/DefaultHttpFirewallTests.java | 4 +- .../web/firewall/StrictHttpFirewallTests.java | 4 +- .../servlet/TestMockHttpServletRequests.java | 169 ++++++++++++++++++ .../matcher/RegexRequestMatcherTests.java | 4 +- 65 files changed, 553 insertions(+), 721 deletions(-) create mode 100644 web/src/test/java/org/springframework/security/web/servlet/TestMockHttpServletRequests.java diff --git a/cas/spring-security-cas.gradle b/cas/spring-security-cas.gradle index cc5c13f604..fd4a614fa5 100644 --- a/cas/spring-security-cas.gradle +++ b/cas/spring-security-cas.gradle @@ -14,6 +14,7 @@ dependencies { provided 'jakarta.servlet:jakarta.servlet-api' + testImplementation project(path : ':spring-security-web', configuration : 'tests') testImplementation "org.assertj:assertj-core" testImplementation "org.junit.jupiter:junit-jupiter-api" testImplementation "org.junit.jupiter:junit-jupiter-params" diff --git a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java index 296043527e..423c99cfe5 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java @@ -55,6 +55,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests {@link CasAuthenticationFilter}. @@ -79,9 +81,7 @@ public class CasAuthenticationFilterTests { @Test public void testNormalOperation() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login/cas"); - request.setServletPath("/login/cas"); - request.addParameter("ticket", "ST-0-ER94xMJmn6pha35CQRoZ"); + MockHttpServletRequest request = post("/login/cas").param("ticket", "ST-0-ER94xMJmn6pha35CQRoZ").build(); CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setAuthenticationManager((a) -> a); assertThat(filter.requiresAuthentication(request, new MockHttpServletResponse())).isTrue(); @@ -104,24 +104,22 @@ public class CasAuthenticationFilterTests { String url = "/login/cas"; CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); - MockHttpServletRequest request = new MockHttpServletRequest("POST", url); + MockHttpServletRequest request = post(url).build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); } @Test public void testRequiresAuthenticationProxyRequest() { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyReceptorUrl(request.getServletPath()); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request.setServletPath("/other"); + request = get("/other").build(); assertThat(filter.requiresAuthentication(request, response)).isFalse(); } @@ -133,12 +131,10 @@ public class CasAuthenticationFilterTests { CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); filter.setServiceProperties(properties); - MockHttpServletRequest request = new MockHttpServletRequest("POST", url); + MockHttpServletRequest request = post(url).build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request = new MockHttpServletRequest("POST", "/other"); - request.setServletPath("/other"); + request = post("/other").build(); assertThat(filter.requiresAuthentication(request, response)).isFalse(); request.setParameter(properties.getArtifactParameter(), "value"); assertThat(filter.requiresAuthentication(request, response)).isTrue(); @@ -156,9 +152,8 @@ public class CasAuthenticationFilterTests { @Test public void testAuthenticateProxyUrl() throws Exception { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyReceptorUrl(request.getServletPath()); assertThat(filter.attemptAuthentication(request, response)).isNull(); @@ -172,9 +167,7 @@ public class CasAuthenticationFilterTests { given(manager.authenticate(any(Authentication.class))).willReturn(authentication); ServiceProperties serviceProperties = new ServiceProperties(); serviceProperties.setAuthenticateAllArtifacts(true); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/authenticate"); - request.setParameter("ticket", "ST-1-123"); - request.setServletPath("/authenticate"); + MockHttpServletRequest request = post("/authenticate").param("ticket", "ST-1-123").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); CasAuthenticationFilter filter = new CasAuthenticationFilter(); @@ -200,10 +193,9 @@ public class CasAuthenticationFilterTests { @Test public void testChainNotInvokedForProxyReceptor() throws Exception { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - request.setServletPath("/pgtCallback"); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyReceptorUrl(request.getServletPath()); filter.doFilter(request, response, chain); @@ -271,16 +263,14 @@ public class CasAuthenticationFilterTests { @Test public void requiresAuthenticationWhenProxyRequestMatcherThenMatches() { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/pgtCallback"); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyReceptorMatcher(PathPatternRequestMatcher.withDefaults().matcher(request.getServletPath())); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request.setRequestURI("/other"); - request.setServletPath("/other"); + request = get("/other").build(); assertThat(filter.requiresAuthentication(request, response)).isFalse(); } diff --git a/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java b/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java index ba9a6fa5ce..71d879e399 100644 --- a/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java @@ -45,6 +45,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link FilterChainProxy}. @@ -144,13 +145,12 @@ public class FilterChainProxyConfigTests { } private void doNormalOperation(FilterChainProxy filterChainProxy) throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setServletPath("/foo/secure/super/somefile.html"); + MockHttpServletRequest request = get("/foo/secure/super/somefile.html").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); filterChainProxy.doFilter(request, response, chain); verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - request.setServletPath("/a/path/which/doesnt/match/any/filter.html"); + request = get("/a/path/which/doesnt/match/any/filter.html").build(); chain = mock(FilterChain.class); filterChainProxy.doFilter(request, response, chain); verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java index 03a8f48e44..e2f2fb5c73 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java @@ -77,7 +77,6 @@ public class AuthorizeRequestsTests { public void setup() { this.servletContext = spy(MockServletContext.mvc()); this.request = new MockHttpServletRequest(this.servletContext, "GET", ""); - this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -110,12 +109,10 @@ public class AuthorizeRequestsTests { @Test public void antMatchersPathVariables() throws Exception { loadConfig(AntPatchersPathVariables.class); - this.request.setServletPath("/user/user"); this.request.setRequestURI("/user/user"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); this.setup(); - this.request.setServletPath("/user/deny"); this.request.setRequestURI("/user/deny"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); @@ -125,12 +122,10 @@ public class AuthorizeRequestsTests { @Test public void antMatchersPathVariablesCaseInsensitive() throws Exception { loadConfig(AntPatchersPathVariables.class); - this.request.setServletPath("/USER/user"); this.request.setRequestURI("/USER/user"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); this.setup(); - this.request.setServletPath("/USER/deny"); this.request.setRequestURI("/USER/deny"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); @@ -140,12 +135,10 @@ public class AuthorizeRequestsTests { @Test public void antMatchersPathVariablesCaseInsensitiveCamelCaseVariables() throws Exception { loadConfig(AntMatchersPathVariablesCamelCaseVariables.class); - this.request.setServletPath("/USER/user"); this.request.setRequestURI("/USER/user"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); this.setup(); - this.request.setServletPath("/USER/deny"); this.request.setRequestURI("/USER/deny"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java index b82c2a57a3..98340ec271 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java @@ -39,6 +39,7 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * @author Rob Winch @@ -48,8 +49,6 @@ public class HttpSecurityLogoutTests { AnnotationConfigWebApplicationContext context; - MockHttpServletRequest request; - MockHttpServletResponse response; MockFilterChain chain; @@ -59,7 +58,6 @@ public class HttpSecurityLogoutTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -77,11 +75,10 @@ public class HttpSecurityLogoutTests { loadConfig(ClearAuthenticationFalseConfig.class); SecurityContext currentContext = SecurityContextHolder.createEmptyContext(); currentContext.setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); - this.request.getSession() + MockHttpServletRequest request = post("/logout").build(); + request.getSession() .setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, currentContext); - this.request.setMethod("POST"); - this.request.setServletPath("/logout"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(currentContext.getAuthentication()).isNotNull(); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java index c970ae7509..76bc331be1 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java @@ -45,6 +45,7 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -54,8 +55,6 @@ public class HttpSecurityRequestMatchersTests { AnnotationConfigWebApplicationContext context; - MockHttpServletRequest request; - MockHttpServletResponse response; MockFilterChain chain; @@ -65,8 +64,6 @@ public class HttpSecurityRequestMatchersTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -87,70 +84,64 @@ public class HttpSecurityRequestMatchersTests { @Test public void requestMatchersMvcMatcherServletPath() throws Exception { loadConfig(RequestMatchersMvcMatcherServeltPathConfig.class); - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get().requestUri(null, "/spring", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath(""); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); setup(); - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "/other", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void requestMatcherWhensMvcMatcherServletPathInLambdaThenPathIsSecured() throws Exception { loadConfig(RequestMatchersMvcMatcherServletPathInLambdaConfig.class); - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get().requestUri(null, "/spring", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath(""); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); setup(); - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "/other", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void requestMatcherWhenMultiMvcMatcherInLambdaThenAllPathsAreDenied() throws Exception { loadConfig(MultiMvcMatcherInLambdaConfig.class); - this.request.setRequestURI("/test-1"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get("/test-1").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-2"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-2").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-3"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-3").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @Test public void requestMatcherWhenMultiMvcMatcherThenAllPathsAreDenied() throws Exception { loadConfig(MultiMvcMatcherConfig.class); - this.request.setRequestURI("/test-1"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get("/test-1").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-2"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-2").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-3"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-3").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java index a1625bce35..34b3d3a379 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java @@ -67,7 +67,7 @@ public class HttpSecuritySecurityMatchersNoMvcTests { @BeforeEach public void setup() throws Exception { - this.request = new MockHttpServletRequest("GET", ""); + this.request = new MockHttpServletRequest(); this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); @@ -83,15 +83,15 @@ public class HttpSecuritySecurityMatchersNoMvcTests { @Test public void securityMatcherWhenNoMvcThenAntMatcher() throws Exception { loadConfig(SecurityMatcherNoMvcConfig.class); - this.request.setServletPath("/path"); + this.request.setRequestURI("/path"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath("/path.html"); + this.request.setRequestURI("/path.html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); setup(); - this.request.setServletPath("/path/"); + this.request.setRequestURI("/path/"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); List requestMatchers = this.springSecurityFilterChain.getFilterChains() .stream() diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java index 65d449efe9..98998e31e8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java @@ -30,14 +30,10 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; -import org.springframework.security.web.context.HttpRequestResponseHolder; -import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.DeferredCsrfToken; @@ -46,14 +42,13 @@ import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * @author Rob Winch */ public class SessionManagementConfigurerServlet31Tests { - MockHttpServletRequest request; - MockHttpServletResponse response; MockFilterChain chain; @@ -64,7 +59,6 @@ public class SessionManagementConfigurerServlet31Tests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -78,13 +72,9 @@ public class SessionManagementConfigurerServlet31Tests { @Test public void changeSessionIdThenPreserveParameters() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/login"); + MockHttpServletRequest request = post("/login").param("username", "user").param("password", "password").build(); String id = request.getSession().getId(); request.getSession(); - request.setServletPath("/login"); - request.setMethod("POST"); - request.setParameter("username", "user"); - request.setParameter("password", "password"); HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, this.response); @@ -106,15 +96,6 @@ public class SessionManagementConfigurerServlet31Tests { this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } - private void login(Authentication auth) { - HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response); - repo.loadContext(requestResponseHolder); - SecurityContextImpl securityContextImpl = new SecurityContextImpl(); - securityContextImpl.setAuthentication(auth); - repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse()); - } - @Configuration @EnableWebSecurity static class SessionManagementDefaultSessionFixationServlet31Config { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index f361eacfd5..d24fc4f723 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -107,6 +107,7 @@ import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; import org.springframework.test.util.ReflectionTestUtils; @@ -127,6 +128,7 @@ import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -185,8 +187,7 @@ public class OAuth2LoginConfigurerTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", "/login/oauth2/code/google"); - this.request.setServletPath("/login/oauth2/code/google"); + this.request = TestMockHttpServletRequests.get("/login/oauth2/code/google").build(); this.response = new MockHttpServletResponse(); this.filterChain = new MockFilterChain(); } @@ -347,7 +348,7 @@ public class OAuth2LoginConfigurerTests { loadConfig(OAuth2LoginConfigLoginProcessingUrl.class); // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.request.setServletPath("/login/oauth2/google"); + this.request.setRequestURI("/login/oauth2/google"); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); @@ -381,8 +382,7 @@ public class OAuth2LoginConfigurerTests { // @formatter:on given(resolver.resolve(any())).willReturn(result); String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = TestMockHttpServletRequests.get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).isEqualTo( "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); @@ -394,8 +394,7 @@ public class OAuth2LoginConfigurerTests { // @formatter:off // @formatter:on String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = TestMockHttpServletRequests.get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).isEqualTo( "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); @@ -418,8 +417,7 @@ public class OAuth2LoginConfigurerTests { // @formatter:on given(resolver.resolve(any())).willReturn(result); String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = TestMockHttpServletRequests.get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).isEqualTo( "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); @@ -432,8 +430,7 @@ public class OAuth2LoginConfigurerTests { RedirectStrategy redirectStrategy = this.context .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class).redirectStrategy; String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); } @@ -445,8 +442,7 @@ public class OAuth2LoginConfigurerTests { RedirectStrategy redirectStrategy = this.context .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class).redirectStrategy; String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); } @@ -456,8 +452,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception { loadConfig(OAuth2LoginConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); } @@ -467,8 +462,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithOneClientConfiguredAndFormLoginThenRedirectDefaultLoginPage() throws Exception { loadConfig(OAuth2LoginConfigFormLogin.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); } @@ -479,8 +473,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginConfig.class); String requestUri = "/favicon.ico"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader(HttpHeaders.ACCEPT, new MediaType("image", "*").toString()); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); @@ -491,8 +484,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithMultipleClientsConfiguredThenRedirectDefaultLoginPage() throws Exception { loadConfig(OAuth2LoginConfigMultipleClients.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); } @@ -503,8 +495,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).doesNotMatch("http://localhost/oauth2/authorization/google"); @@ -515,8 +506,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginWithHttpBasicConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getStatus()).isEqualTo(401); @@ -527,8 +517,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginWithXHREntryPointConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getStatus()).isEqualTo(401); @@ -540,8 +529,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginConfigAuthorizationCodeClientAndOtherClients.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); } @@ -550,8 +538,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws Exception { loadConfig(OAuth2LoginConfigCustomLoginPage.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); } @@ -560,8 +547,7 @@ public class OAuth2LoginConfigurerTests { public void requestWhenOauth2LoginWithCustomLoginPageInLambdaThenRedirectCustomLoginPage() throws Exception { loadConfig(OAuth2LoginConfigCustomLoginPageInLambda.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java index df28e06209..30907b5e8d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java @@ -89,6 +89,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OidcUserRefreshedEventListener} with {@link OAuth2LoginConfigurer}. @@ -147,8 +148,7 @@ public class OidcUserRefreshedEventListenerConfigurationTests { @BeforeEach public void setUp() { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/"); + this.request = get("/").build(); this.response = new MockHttpServletResponse(); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java index 6b8f82a8bd..84c7a3e42b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java @@ -42,6 +42,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OidcUserRefreshedEventListener}. @@ -64,8 +65,7 @@ public class OidcUserRefreshedEventListenerTests { this.eventListener = new OidcUserRefreshedEventListener(); this.eventListener.setSecurityContextRepository(this.securityContextRepository); - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/"); + this.request = get("/").build(); this.response = new MockHttpServletResponse(); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index f43ae69d86..102bf1e517 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -94,6 +94,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -190,8 +191,7 @@ public class Saml2LoginConfigurerTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("POST", ""); - this.request.setServletPath("/login/saml2/sso/test-rp"); + this.request = TestMockHttpServletRequests.post("/login/saml2/sso/test-rp").build(); this.response = new MockHttpServletResponse(); this.filterChain = new MockFilterChain(); } @@ -430,7 +430,6 @@ public class Saml2LoginConfigurerTests { private void performSaml2Login(String expected) throws IOException, ServletException { // setup authentication parameters this.request.setRequestURI("/login/saml2/sso/registration-id"); - this.request.setServletPath("/login/saml2/sso/registration-id"); this.request.setParameter("SAMLResponse", Base64.getEncoder().encodeToString("saml2-xml-response-object".getBytes())); // perform test diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java index 72c638807f..573e5331a0 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java @@ -77,6 +77,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.logout.LogoutFilter; import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -159,8 +160,7 @@ public class Saml2LogoutConfigurerTests { Collections.emptyMap()); principal.setRelyingPartyRegistrationId("registration-id"); this.user = new Saml2Authentication(principal, "response", AuthorityUtils.createAuthorityList("ROLE_USER")); - this.request = new MockHttpServletRequest("POST", ""); - this.request.setServletPath("/login/saml2/sso/test-rp"); + this.request = TestMockHttpServletRequests.post("/login/saml2/sso/test-rp").build(); this.response = new MockHttpServletResponse(); } diff --git a/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java index dd6c1264d3..5bd2d50f7a 100644 --- a/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java @@ -129,9 +129,7 @@ public class FilterSecurityMetadataSourceBeanDefinitionParserTests { } private FilterInvocation createFilterInvocation(String path, String method) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setRequestURI(path); - request.setMethod(method); + MockHttpServletRequest request = new MockHttpServletRequest(method, path); return new FilterInvocation(request, new MockHttpServletResponse(), new MockFilterChain()); } diff --git a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java index 4045c1b689..71da35955b 100644 --- a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java @@ -134,8 +134,7 @@ public class Saml2LogoutBeanDefinitionParserTests { principal.setRelyingPartyRegistrationId("registration-id"); this.saml2User = new Saml2Authentication(principal, "response", AuthorityUtils.createAuthorityList("ROLE_USER")); - this.request = new MockHttpServletRequest("POST", ""); - this.request.setServletPath("/login/saml2/sso/test-rp"); + this.request = new MockHttpServletRequest("POST", "/login/saml2/sso/test-rp"); this.response = new MockHttpServletResponse(); } diff --git a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java index 03c50c9db9..333c2db61f 100644 --- a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java @@ -26,10 +26,7 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.util.InMemoryXmlApplicationContext; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextImpl; -import org.springframework.security.web.context.HttpRequestResponseHolder; -import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; @@ -61,7 +58,7 @@ public class SessionManagementConfigServlet31Tests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); + this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -75,12 +72,11 @@ public class SessionManagementConfigServlet31Tests { @Test public void changeSessionIdThenPreserveParameters() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletRequest request = TestMockHttpServletRequests.post("/login") + .param("username", "user") + .param("password", "password") + .build(); request.getSession(); - request.setServletPath("/login"); - request.setMethod("POST"); - request.setParameter("username", "user"); - request.setParameter("password", "password"); request.getSession().setAttribute("attribute1", "value1"); String id = request.getSession().getId(); // @formatter:off @@ -99,12 +95,11 @@ public class SessionManagementConfigServlet31Tests { @Test public void changeSessionId() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletRequest request = TestMockHttpServletRequests.post("/login") + .param("username", "user") + .param("password", "password") + .build(); request.getSession(); - request.setServletPath("/login"); - request.setMethod("POST"); - request.setParameter("username", "user"); - request.setParameter("password", "password"); String id = request.getSession().getId(); // @formatter:off loadContext("\n" @@ -124,13 +119,4 @@ public class SessionManagementConfigServlet31Tests { this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } - private void login(Authentication auth) { - HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response); - repo.loadContext(requestResponseHolder); - SecurityContextImpl securityContextImpl = new SecurityContextImpl(); - securityContextImpl.setAuthentication(auth); - repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse()); - } - } diff --git a/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java b/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java index c47bf68b6e..01d2f83c61 100644 --- a/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java @@ -60,7 +60,7 @@ public class CustomHttpSecurityConfigurerTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); + this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); this.request.setMethod("GET"); @@ -76,7 +76,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerPermitAll() throws Exception { loadContext(Config.class); - this.request.setPathInfo("/public/something"); + this.request.setRequestURI("/public/something"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @@ -84,7 +84,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerFormLogin() throws Exception { loadContext(Config.class); - this.request.setPathInfo("/requires-authentication"); + this.request.setRequestURI("/requires-authentication"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getRedirectedUrl()).endsWith("/custom"); } @@ -92,7 +92,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerCustomizeDisablesCsrf() throws Exception { loadContext(ConfigCustomize.class); - this.request.setPathInfo("/public/something"); + this.request.setRequestURI("/public/something"); this.request.setMethod("POST"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -101,7 +101,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerCustomizeFormLogin() throws Exception { loadContext(ConfigCustomize.class); - this.request.setPathInfo("/requires-authentication"); + this.request.setRequestURI("/requires-authentication"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getRedirectedUrl()).endsWith("/other"); } diff --git a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java index 68a6c22ab1..29e3cbbda1 100644 --- a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java @@ -40,6 +40,7 @@ import org.springframework.security.saml2.provider.service.web.authentication.Op import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link RelyingPartyRegistrationsBeanDefinitionParser}. @@ -282,9 +283,7 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests { Converter relayStateResolver = this.spring.getContext().getBean(Converter.class); OpenSaml4AuthenticationRequestResolver authenticationRequestResolver = this.spring.getContext() .getBean(OpenSaml4AuthenticationRequestResolver.class); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/saml2/authenticate/one"); - request.setServletPath("/saml2/authenticate/one"); + MockHttpServletRequest request = get("/saml2/authenticate/one").build(); authenticationRequestResolver.resolve(request); verify(relayStateResolver).convert(request); } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt index 42eda2bbfc..ed70c24915 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt @@ -44,8 +44,6 @@ import org.springframework.web.bind.annotation.PathVariable import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RestController import org.springframework.web.servlet.config.annotation.EnableWebMvc -import org.springframework.web.servlet.config.annotation.PathMatchConfigurer -import org.springframework.web.servlet.config.annotation.WebMvcConfigurer /** * Tests for [AuthorizeRequestsDsl] @@ -405,17 +403,11 @@ class AuthorizeRequestsDslTests { this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") - .with { request -> - request.servletPath = "/spring" - request - }) + .servletPath("/spring")) .andExpect(status().isForbidden) this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") - .with { request -> - request.servletPath = "/other" - request - }) + .servletPath("/other")) .andExpect(status().isOk) } @@ -514,28 +506,15 @@ class AuthorizeRequestsDslTests { this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") - .with { request -> - request.apply { - servletPath = "/spring" - } - }) + .servletPath("/spring")) .andExpect(status().isForbidden) this.mockMvc.perform(MockMvcRequestBuilders.put("/spring/path") - .with { request -> - request.apply { - servletPath = "/spring" - csrf() - } - }) + .servletPath("/spring")) .andExpect(status().isForbidden) this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") - .with { request -> - request.apply { - servletPath = "/other" - } - }) + .servletPath("/other")) .andExpect(status().isOk) } } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt index 28632c1d92..439e64b663 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt @@ -83,18 +83,12 @@ class RequiresChannelDslTests { this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") - .with { request -> - request.servletPath = "/spring" - request - }) + .servletPath("/spring")) .andExpect(status().isFound) .andExpect(redirectedUrl("https://localhost/spring/path")) this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") - .with { request -> - request.servletPath = "/other" - request - }) + .servletPath("/other")) .andExpect(MockMvcResultMatchers.status().isOk) } diff --git a/etc/checkstyle/checkstyle.xml b/etc/checkstyle/checkstyle.xml index e4cbc33131..297f8a8442 100644 --- a/etc/checkstyle/checkstyle.xml +++ b/etc/checkstyle/checkstyle.xml @@ -18,6 +18,7 @@ + diff --git a/itest/context/spring-security-itest-context.gradle b/itest/context/spring-security-itest-context.gradle index c278418f74..29e83a6161 100644 --- a/itest/context/spring-security-itest-context.gradle +++ b/itest/context/spring-security-itest-context.gradle @@ -9,7 +9,7 @@ dependencies { implementation 'org.springframework:spring-context' implementation 'org.springframework:spring-tx' - testImplementation project(':spring-security-web') + testImplementation project(path: ':spring-security-web', configuration: 'tests') testImplementation 'jakarta.servlet:jakarta.servlet-api' testImplementation 'org.springframework:spring-web' testImplementation "org.assertj:assertj-core" diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java index aab6742b81..cf869c7ef3 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java @@ -43,9 +43,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests { @Test public void requestThatIsMatchedByDefaultInterceptorIsAllowed() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setMethod("GET"); - request.setServletPath("/somefile.html"); + MockHttpServletRequest request = TestMockHttpServletRequests.get("/somefile.html").build(); request.setSession(createAuthenticatedSession("ROLE_0", "ROLE_1", "ROLE_2")); MockHttpServletResponse response = new MockHttpServletResponse(); this.fcp.doFilter(request, response, new MockFilterChain()); @@ -54,10 +52,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests { @Test public void securedUrlAccessIsRejectedWithoutRequiredRole() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setMethod("GET"); - - request.setServletPath("/secure/somefile.html"); + MockHttpServletRequest request = TestMockHttpServletRequests.get("/secure/somefile.html").build(); request.setSession(createAuthenticatedSession("ROLE_0")); MockHttpServletResponse response = new MockHttpServletResponse(); this.fcp.doFilter(request, response, new MockFilterChain()); diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index 93bab342cf..11b6c91f0a 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -18,6 +18,7 @@ dependencies { testImplementation project(path: ':spring-security-oauth2-core', configuration: 'tests') testImplementation project(path: ':spring-security-oauth2-jose', configuration: 'tests') + testImplementation project(path: ':spring-security-web', configuration: 'tests') testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation 'io.micrometer:context-propagation' testImplementation 'io.projectreactor.netty:reactor-netty' diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index 1382a0368f..92a170394a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -44,6 +44,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.entry; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}. @@ -123,8 +125,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { @Test public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNull(); } @@ -133,7 +134,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { @Test public void resolveWhenNotAuthorizationRequestThenRequestBodyNotConsumed() throws IOException { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + MockHttpServletRequest request = post(requestUri).build(); request.setContent("foo".getBytes(StandardCharsets.UTF_8)); request.setCharacterEncoding(StandardCharsets.UTF_8.name()); HttpServletRequest spyRequest = Mockito.spy(request); @@ -151,8 +152,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() + "-invalid"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); // @formatter:off assertThatIllegalArgumentException() .isThrownBy(() -> this.resolver.resolve(request)) @@ -164,8 +164,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestWithValidClientThenResolves() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAuthorizationUri()) @@ -191,8 +190,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves() { ClientRegistration clientRegistration = this.registration2; String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId()); assertThat(authorizationRequest).isNotNull(); @@ -204,8 +202,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -216,9 +213,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServerPort(8080); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("localhost:8080" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -229,10 +224,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerPort(8081); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("https://localhost:8081" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -243,10 +235,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("http"); - request.setServerPort(80); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("http://localhost" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -257,10 +246,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerPort(443); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("https://localhost:443" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -271,10 +257,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestHasNoPortThenInvalidUrlException() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerPort(-1); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).port(-1).build(); assertThatExceptionOfType(InvalidUrlException.class).isThrownBy(() -> this.resolver.resolve(request)); } @@ -283,9 +266,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.setQueryString("foo=bar"); + MockHttpServletRequest request = get(requestUri + "?foo=bar").build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -296,11 +277,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("http"); - request.setServerName("localhost"); - request.setServerPort(80); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" @@ -312,11 +289,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerName("example.com"); - request.setServerPort(443); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("https://example.com:443" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" @@ -328,8 +301,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirectUriIsAuthorize() { ClientRegistration clientRegistration = this.registration1; String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId()); assertThat(authorizationRequest.getAuthorizationRequestUri()) @@ -342,8 +314,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" @@ -355,9 +326,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.addParameter("action", "authorize"); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).param("action", "authorize").build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" @@ -369,9 +338,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.addParameter("action", "login"); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).param("action", "login").build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" @@ -383,8 +350,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestWithValidPublicClientThenResolves() { ClientRegistration clientRegistration = this.publicClientRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAuthorizationUri()) @@ -420,15 +386,13 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertPkceApplied(authorizationRequest, clientRegistration); clientRegistration = this.registration2; requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + request = get(requestUri).build(); authorizationRequest = this.resolver.resolve(request); assertPkceApplied(authorizationRequest, clientRegistration); } @@ -447,15 +411,13 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertPkceApplied(authorizationRequest, clientRegistration); clientRegistration = this.registration2; requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + request = get(requestUri).build(); authorizationRequest = this.resolver.resolve(request); assertPkceNotApplied(authorizationRequest, clientRegistration); } @@ -491,8 +453,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAuthorizationUri()) @@ -524,8 +485,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); this.resolver.setAuthorizationRequestCustomizer( (builder) -> builder.additionalParameters((params) -> params.remove(OidcParameterNames.NONCE)) .attributes((attrs) -> attrs.remove(OidcParameterNames.NONCE))); @@ -543,8 +503,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.authorizationRequestUri((uriBuilder) -> { uriBuilder.queryParam("param1", "value1"); return uriBuilder.build(); @@ -561,8 +520,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.parameters((params) -> { params.put("appid", params.get("client_id")); params.remove("client_id"); @@ -579,8 +537,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { OAuth2AuthorizationRequestResolver resolver = new DefaultOAuth2AuthorizationRequestResolver( this.clientRegistrationRepository); String requestUri = this.authorizationRequestBaseUri + "/" + this.registration2.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()) .isEqualTo("http://localhost/login/oauth2/code/" + this.registration2.getRegistrationId()); @@ -590,8 +547,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() { ClientRegistration clientRegistration = this.pkceClientRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAdditionalParameters().containsKey(PkceParameterNames.CODE_CHALLENGE_METHOD)) .isTrue(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index a6f1f9bc96..440b3a6383 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -72,6 +72,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OAuth2AuthorizationCodeGrantFilter}. @@ -154,8 +155,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @Test public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); // NOTE: A valid Authorization Response contains either a 'code' or 'error' // parameter. MockHttpServletResponse response = new MockHttpServletResponse(); @@ -328,8 +328,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @Test public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception { String requestUri = "/saved-request"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); RequestCache requestCache = new HttpSessionRequestCache(); requestCache.saveRequest(request, response); @@ -430,8 +429,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map parameters) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); if (!CollectionUtils.isEmpty(parameters)) { parameters.forEach(request::addParameter); request.setQueryString(parameters.entrySet() diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 59676b1461..bc51ba9e7b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -55,6 +55,7 @@ import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}. @@ -127,8 +128,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { @Test public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -139,8 +139,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalServerError() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId() + "-invalid"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -154,8 +153,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId() + "-invalid"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> { @@ -178,8 +176,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -193,8 +190,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenAuthorizationRequestOAuth2LoginThenAuthorizationRequestSaved() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration2.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); AuthorizationRequestRepository authorizationRequestRepository = mock( @@ -212,8 +208,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, authorizationRequestBaseUri); String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -227,8 +222,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectForAuthorization() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) @@ -245,8 +239,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) @@ -266,8 +259,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); request.addParameter("idp", "https://other.provider.com"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -295,8 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); String loginHintParamName = "login_hint"; request.addParameter(loginHintParamName, "user@provider.com"); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -335,8 +326,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> { @@ -363,8 +353,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); willAnswer((invocation) -> assertThat((invocation.getArgument(1)).isCommitted()).isFalse()) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 3dee3c0cc0..5e7b5a4211 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -69,6 +69,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OAuth2LoginAuthenticationFilter}. @@ -163,8 +164,7 @@ public class OAuth2LoginAuthenticationFilterTests { @Test public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -176,8 +176,7 @@ public class OAuth2LoginAuthenticationFilterTests { @Test public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); // NOTE: // A valid Authorization Response contains either a 'code' or 'error' parameter. // Don't set it to force an invalid Authorization Response. @@ -198,10 +197,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -221,10 +219,9 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); // @formatter:off @@ -258,10 +255,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -274,10 +270,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration1, state); @@ -300,10 +295,9 @@ public class OAuth2LoginAuthenticationFilterTests { this.filter.setAuthenticationManager(this.authenticationManager); String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -319,13 +313,9 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("http"); - request.setServerName("localhost"); - request.setServerPort(80); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -350,13 +340,10 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerName("example.com"); - request.setServerPort(443); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get("https://example.com:443" + requestUri) + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -381,13 +368,10 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerName("example.com"); - request.setServerPort(9090); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get("https://example.com:9090" + requestUri) + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -411,10 +395,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationResult() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); WebAuthenticationDetails webAuthenticationDetails = mock(WebAuthenticationDetails.class); given(this.authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -430,10 +413,9 @@ public class OAuth2LoginAuthenticationFilterTests { this.filter.setAuthenticationResultConverter((authentication) -> null); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthenticationResult(this.registration1); @@ -448,10 +430,9 @@ public class OAuth2LoginAuthenticationFilterTests { authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId())); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthenticationResult(this.registration1); diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index b05c1bbd57..64511c5079 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -108,6 +108,7 @@ dependencies { optional 'com.fasterxml.jackson.core:jackson-databind' optional 'org.springframework:spring-jdbc' + testImplementation project(path: ':spring-security-web', configuration: 'tests') testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation "org.assertj:assertj-core" testImplementation "org.skyscreamer:jsonassert" diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java index 57f4221260..1635b4aa53 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -216,15 +217,11 @@ public final class OpenSaml4AuthenticationTokenConverterTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private T signed(T toSign) { diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java index abcf3a4c78..ffb2196836 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -216,15 +217,11 @@ public final class OpenSamlAuthenticationTokenConverterTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private T signed(T toSign) { diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java index 7ba18fb985..a1a172a8ed 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java @@ -103,9 +103,7 @@ public class OpenSaml4AuthenticationRequestResolverTests { } private MockHttpServletRequest givenRequest(String path) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", path); - request.setServletPath(path); - return request; + return new MockHttpServletRequest("GET", path); } } diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java index c7aeb8b878..cdf9cd7712 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java @@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -135,15 +136,11 @@ public final class OpenSaml4LogoutRequestValidatorParametersResolverTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private String serialize(XMLObject object) { diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java index b57db0a895..e036b749cc 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java @@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -135,15 +136,11 @@ public final class OpenSamlLogoutRequestValidatorParametersResolverTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private String serialize(XMLObject object) { diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java index 1c35ec58e3..dcf7617c6b 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -216,15 +217,11 @@ public final class OpenSaml5AuthenticationTokenConverterTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private T signed(T toSign) { diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java index 38b44d5b55..c0163f7a10 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java @@ -28,6 +28,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; @@ -103,9 +104,7 @@ public class OpenSaml5AuthenticationRequestResolverTests { } private MockHttpServletRequest givenRequest(String path) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", path); - request.setServletPath(path); - return request; + return TestMockHttpServletRequests.get(path).build(); } } diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java index 8ec0cef306..af13bef179 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java @@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -135,15 +136,11 @@ public final class OpenSaml5LogoutRequestValidatorParametersResolverTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private String serialize(XMLObject object) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java index 0684218ffd..1145cca686 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java @@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.registration.InMemory import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -121,9 +122,7 @@ public final class RequestMatcherMetadataResponseResolverTests { } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private RelyingPartyRegistration withEntityId(String entityId) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java index cd4d88f62a..d24ba9d736 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java @@ -46,6 +46,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link Saml2LogoutRequestFilter} @@ -76,9 +77,8 @@ public class Saml2LogoutRequestFilterTests { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success()); @@ -105,9 +105,8 @@ public class Saml2LogoutRequestFilterTests { given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication)); this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success()); @@ -127,9 +126,7 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout"); - request.setServletPath("/logout"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout").param(Saml2ParameterNames.SAML_RESPONSE, "response").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler); @@ -139,8 +136,7 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); + MockHttpServletRequest request = post("/logout/saml2/slo").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler); @@ -153,9 +149,8 @@ public class Saml2LogoutRequestFilterTests { .build(); Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration) .samlResponse("response") @@ -182,7 +177,6 @@ public class Saml2LogoutRequestFilterTests { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() @@ -210,9 +204,8 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenInvalidBindingErrorLogoutResponseIsPosted() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) @@ -242,9 +235,8 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenNoErrorResponseCanBeGeneratedThen401() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java index 5973f9589e..77f43dad57 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java @@ -43,6 +43,8 @@ import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link Saml2LogoutResponseFilter} @@ -74,9 +76,8 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenSamlResponsePostThenLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration); @@ -94,8 +95,7 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenSamlResponseRedirectThenLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); + MockHttpServletRequest request = get("/logout/saml2/slo").build(); request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() @@ -116,9 +116,7 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout"); - request.setServletPath("/logout"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout").param(Saml2ParameterNames.SAML_REQUEST, "request").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler); @@ -128,8 +126,7 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); + MockHttpServletRequest request = post("/logout/saml2/slo").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler); @@ -139,9 +136,8 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenValidatorFailsThenStops() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration); @@ -160,9 +156,8 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenNoRelyingPartyLogoutThen401() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .singleLogoutServiceLocation(null) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java index 2823991574..862ea2548c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java @@ -39,6 +39,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link Saml2RelyingPartyInitiatedLogoutSuccessHandler} @@ -72,8 +73,7 @@ public class Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests { Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .samlRequest("request") .build(); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/saml2/logout"); - request.setServletPath("/saml2/logout"); + MockHttpServletRequest request = post("/saml2/logout").build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest); this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -92,8 +92,7 @@ public class Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests { Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .samlRequest("request") .build(); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/saml2/logout"); - request.setServletPath("/saml2/logout"); + MockHttpServletRequest request = post("/saml2/logout").build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest); this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 2e8f7a552a..a1b1a5300f 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -64,6 +64,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Luke Taylor @@ -96,8 +97,7 @@ public class FilterChainProxyTests { }).given(this.filter).doFilter(any(), any(), any()); this.fcp = new FilterChainProxy(new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter))); this.fcp.setFilterChainValidator(mock(FilterChainProxy.FilterChainValidator.class)); - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/path"); + this.request = get("/path").build(); this.response = new MockHttpServletResponse(); this.chain = mock(FilterChain.class); } diff --git a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java index 5b6fc258ea..58a6fd2f98 100644 --- a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java @@ -34,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link FilterInvocation}. @@ -45,14 +46,8 @@ public class FilterInvocationTests { @Test public void testGettersAndStringMethods() { - MockHttpServletRequest request = new MockHttpServletRequest(null, null); - request.setServletPath("/HelloWorld"); - request.setPathInfo("/some/more/segments.html"); - request.setServerName("localhost"); - request.setScheme("http"); - request.setServerPort(80); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/HelloWorld/some/more/segments.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", "/some/more/segments.html") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); FilterInvocation fi = new FilterInvocation(request, response, chain); @@ -62,7 +57,7 @@ public class FilterInvocationTests { assertThat(fi.getHttpResponse()).isEqualTo(response); assertThat(fi.getChain()).isEqualTo(chain); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld/some/more/segments.html"); - assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld/some/more/segments.html]"); + assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld/some/more/segments.html]"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld/some/more/segments.html"); } @@ -89,34 +84,23 @@ public class FilterInvocationTests { @Test public void testStringMethodsWithAQueryString() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("foo=bar"); - request.setServletPath("/HelloWorld"); - request.setServerName("localhost"); - request.setScheme("http"); - request.setServerPort(80); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/HelloWorld"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", null) + .queryString("foo=bar") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld?foo=bar"); - assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld?foo=bar]"); + assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld?foo=bar]"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld?foo=bar"); } @Test public void testStringMethodsWithoutAnyQueryString() { - MockHttpServletRequest request = new MockHttpServletRequest(null, null); - request.setServletPath("/HelloWorld"); - request.setServerName("localhost"); - request.setScheme("http"); - request.setServerPort(80); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/HelloWorld"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld"); - assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld]"); + assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld]"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld"); } diff --git a/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java b/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java index 1340c36365..b03e6d4b07 100644 --- a/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java @@ -29,6 +29,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link RequestMatcherRedirectFilter}. @@ -44,9 +45,7 @@ public class RequestMatcherRedirectFilterTests { RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(this.builder.matcher("/context"), "/test"); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/context"); - + MockHttpServletRequest request = get("/context").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -63,8 +62,7 @@ public class RequestMatcherRedirectFilterTests { RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(this.builder.matcher("/context"), "/test"); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/test"); + MockHttpServletRequest request = get("/test").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); diff --git a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java index 89159fd737..11fe3e41d2 100644 --- a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java @@ -58,6 +58,7 @@ import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link ExceptionTranslationFilter}. @@ -86,13 +87,7 @@ public class ExceptionTranslationFilterTests { @Test public void testAccessDeniedWhenAnonymous() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(80); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is @@ -129,13 +124,7 @@ public class ExceptionTranslationFilterTests { @Test public void testAccessDeniedWithRememberMe() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(80); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is remembered @@ -155,8 +144,7 @@ public class ExceptionTranslationFilterTests { @Test public void testAccessDeniedWhenNonAnonymous() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is @@ -178,8 +166,7 @@ public class ExceptionTranslationFilterTests { @Test public void testLocalizedErrorMessages() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is @@ -202,13 +189,7 @@ public class ExceptionTranslationFilterTests { @Test public void redirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthenticationException() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(80); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build(); // Setup the FilterChain to thrown an authentication failure exception FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); // Test @@ -225,13 +206,9 @@ public class ExceptionTranslationFilterTests { public void redirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(8080); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get("http://localhost:8080") + .requestUri("/mycontext", "/secure/page.html", null) + .build(); // Setup the FilterChain to thrown an authentication failure exception FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); // Test @@ -258,8 +235,7 @@ public class ExceptionTranslationFilterTests { @Test public void successfulAccessGrant() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); // Test ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint); assertThat(filter.getAuthenticationEntryPoint()).isSameAs(this.mockEntryPoint); diff --git a/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java index 348b3b2152..ad3f3afa66 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java @@ -32,6 +32,7 @@ import org.springframework.security.web.access.intercept.FilterInvocationSecurit import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link ChannelProcessingFilter}. @@ -81,9 +82,8 @@ public class ChannelProcessingFilterTests { filter.setChannelDecisionManager(new MockChannelDecisionManager(true, "SOME_ATTRIBUTE")); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE"); filter.setSecurityMetadataSource(fids); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/path").build(); request.setQueryString("info=now"); - request.setServletPath("/path"); MockHttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, mock(FilterChain.class)); } @@ -94,9 +94,8 @@ public class ChannelProcessingFilterTests { filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "SOME_ATTRIBUTE")); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE"); filter.setSecurityMetadataSource(fids); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/path").build(); request.setQueryString("info=now"); - request.setServletPath("/path"); MockHttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, mock(FilterChain.class)); } @@ -107,9 +106,8 @@ public class ChannelProcessingFilterTests { filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "NOT_USED")); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "NOT_USED"); filter.setSecurityMetadataSource(fids); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/PATH_NOT_MATCHING_CONFIG_ATTRIBUTE").build(); request.setQueryString("info=now"); - request.setServletPath("/PATH_NOT_MATCHING_CONFIG_ATTRIBUTE"); MockHttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, mock(FilterChain.class)); } diff --git a/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java b/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java index 1a3f8f1480..1c290cac06 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java @@ -27,6 +27,7 @@ import org.springframework.security.web.FilterInvocation; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link InsecureChannelProcessor}. @@ -37,13 +38,9 @@ public class InsecureChannelProcessorTests { @Test public void testDecideDetectsAcceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("http"); - request.setServerPort(8080); + MockHttpServletRequest request = get("http://localhost:8080").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); InsecureChannelProcessor processor = new InsecureChannelProcessor(); @@ -53,14 +50,9 @@ public class InsecureChannelProcessorTests { @Test public void testDecideDetectsUnacceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("https"); - request.setSecure(true); - request.setServerPort(8443); + MockHttpServletRequest request = get("https://localhost:8443").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); InsecureChannelProcessor processor = new InsecureChannelProcessor(); diff --git a/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java b/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java index 005736b336..4263cec233 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java @@ -27,6 +27,7 @@ import org.springframework.security.web.FilterInvocation; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link SecureChannelProcessor}. @@ -37,14 +38,9 @@ public class SecureChannelProcessorTests { @Test public void testDecideDetectsAcceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("https"); - request.setSecure(true); - request.setServerPort(8443); + MockHttpServletRequest request = get("https://localhost:8443").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); SecureChannelProcessor processor = new SecureChannelProcessor(); @@ -54,13 +50,9 @@ public class SecureChannelProcessorTests { @Test public void testDecideDetectsUnacceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("http"); - request.setServerPort(8080); + MockHttpServletRequest request = get("http://localhost:8080").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); SecureChannelProcessor processor = new SecureChannelProcessor(); diff --git a/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java b/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java index d9bb23e6a2..aec2145136 100644 --- a/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java @@ -31,6 +31,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.FilterInvocation; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -54,8 +55,7 @@ public class AbstractVariableEvaluationContextPostProcessorTests { @BeforeEach public void setup() { this.processor = new VariableEvaluationContextPostProcessor(); - this.request = new MockHttpServletRequest(); - this.request.setServletPath("/"); + this.request = get("/").build(); this.response = new MockHttpServletResponse(); this.invocation = new FilterInvocation(this.request, this.response, new MockFilterChain()); this.context = new StandardEvaluationContext(); diff --git a/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java b/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java index 5460be992c..1e9fac9ebe 100644 --- a/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java +++ b/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java @@ -33,6 +33,7 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.request; /** * Tests {@link DefaultFilterInvocationSecurityMetadataSource}. @@ -141,12 +142,9 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests { private FilterInvocation createFilterInvocation(String servletPath, String pathInfo, String queryString, String method) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI(null); - request.setMethod(method); - request.setServletPath(servletPath); - request.setPathInfo(pathInfo); - request.setQueryString(queryString); + MockHttpServletRequest request = request(method).requestUri(null, servletPath, pathInfo) + .queryString(queryString) + .build(); return new FilterInvocation(request, new MockHttpServletResponse(), mock(FilterChain.class)); } diff --git a/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java b/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java index 4408b14147..77370b0cd7 100644 --- a/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java @@ -53,6 +53,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link FilterSecurityInterceptor}. @@ -188,8 +189,7 @@ public class FilterSecurityInterceptorTests { private FilterInvocation createinvocation() { MockHttpServletResponse response = new MockHttpServletResponse(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); FilterChain chain = mock(FilterChain.class); FilterInvocation fi = new FilterInvocation(request, response, chain); return fi; diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java index 70cf88853b..1ba9b34187 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java @@ -59,6 +59,9 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.Builder; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests {@link AbstractAuthenticationProcessingFilter}. @@ -75,13 +78,11 @@ public class AbstractAuthenticationProcessingFilterTests { SimpleUrlAuthenticationFailureHandler failureHandler; private MockHttpServletRequest createMockAuthenticationRequest() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/j_mock_post"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setRequestURI("/mycontext/j_mock_post"); - request.setContextPath("/mycontext"); - return request; + return withMockAuthenticationRequest().build(); + } + + private Builder withMockAuthenticationRequest() { + return get("www.example.com").requestUri("/mycontext", "/j_mock_post", null); } @BeforeEach @@ -100,12 +101,11 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testDefaultProcessesFilterUrlMatchesWithPathParameter() { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login;jsessionid=I8MIONOSTHOR"); + MockHttpServletRequest request = post("/login;jsessionid=I8MIONOSTHOR").build(); MockHttpServletResponse response = new MockHttpServletResponse(); MockAuthenticationFilter filter = new MockAuthenticationFilter(); filter.setFilterProcessesUrl("/login"); DefaultHttpFirewall firewall = new DefaultHttpFirewall(); - request.setServletPath("/login;jsessionid=I8MIONOSTHOR"); // the firewall ensures that path parameters are ignored HttpServletRequest firewallRequest = firewall.getFirewalledRequest(request); assertThat(filter.requiresAuthentication(firewallRequest, response)).isTrue(); @@ -114,9 +114,9 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testFilterProcessesUrlVariationsRespected() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = createMockAuthenticationRequest(); - request.setServletPath("/j_OTHER_LOCATION"); - request.setRequestURI("/mycontext/j_OTHER_LOCATION"); + MockHttpServletRequest request = withMockAuthenticationRequest() + .requestUri("/mycontext", "/j_OTHER_LOCATION", null) + .build(); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); // Setup our expectation that the filter chain will not be invoked, as we redirect @@ -150,9 +150,9 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testIgnoresAnyServletPathOtherThanFilterProcessesUrl() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = createMockAuthenticationRequest(); - request.setServletPath("/some.file.html"); - request.setRequestURI("/mycontext/some.file.html"); + MockHttpServletRequest request = withMockAuthenticationRequest() + .requestUri("/mycontext", "/some.file.html", null) + .build(); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); // Setup our expectation that the filter chain will be invoked, as our request is @@ -227,9 +227,9 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testNormalOperationWithRequestMatcherAndAuthenticationManager() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = createMockAuthenticationRequest(); - request.setServletPath("/j_eradicate_corona_virus"); - request.setRequestURI("/mycontext/j_eradicate_corona_virus"); + MockHttpServletRequest request = withMockAuthenticationRequest() + .requestUri("/mycontext", "/j_eradicate_corona_virus", null) + .build(); HttpSession sessionPreAuth = request.getSession(); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); diff --git a/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java b/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java index 91e2d93cdf..a1483034c0 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java @@ -28,6 +28,7 @@ import org.springframework.security.web.PortMapperImpl; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link LoginUrlAuthenticationEntryPoint}. @@ -73,12 +74,7 @@ public class LoginUrlAuthenticationEntryPointTests { @Test public void testHttpsOperationFromOriginalHttpUrl() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get("http://127.0.0.1").requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); ep.setPortMapper(new PortMapperImpl()); @@ -87,17 +83,17 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortResolver(new MockPortResolver(80, 443)); ep.afterPropertiesSet(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1/bigWebApp/hello"); request.setServerPort(8080); response = new MockHttpServletResponse(); ep.setPortResolver(new MockPortResolver(8080, 8443)); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:8443/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:8443/bigWebApp/hello"); // Now test an unusual custom HTTP:HTTPS is handled properly request.setServerPort(8888); response = new MockHttpServletResponse(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:8443/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:8443/bigWebApp/hello"); PortMapperImpl portMapper = new PortMapperImpl(); Map map = new HashMap<>(); map.put("8888", "9999"); @@ -110,17 +106,13 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortResolver(new MockPortResolver(8888, 9999)); ep.afterPropertiesSet(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:9999/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:9999/bigWebApp/hello"); } @Test public void testHttpsOperationFromOriginalHttpsUrl() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setScheme("https"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(443); + MockHttpServletRequest request = get("https://www.example.com:443").requestUri("/bigWebApp", "/some_path", null) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); ep.setPortMapper(new PortMapperImpl()); @@ -149,13 +141,7 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortMapper(new PortMapperImpl()); ep.setPortResolver(new MockPortResolver(80, 443)); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get().requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/bigWebApp/hello"); @@ -167,13 +153,8 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortResolver(new MockPortResolver(8888, 1234)); ep.setForceHttps(true); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/bigWebApp"); - request.setServerPort(8888); // NB: Port we can't resolve + MockHttpServletRequest request = get("http://localhost:8888").requestUri("/bigWebApp", "/some_path", null) + .build(); // NB: Port we can't resolve MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); // Response doesn't switch to HTTPS, as we didn't know HTTP port 8888 to HTTP port @@ -186,14 +167,7 @@ public class LoginUrlAuthenticationEntryPointTests { LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); ep.setUseForward(true); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/bigWebApp/some_path"); - request.setServletPath("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get().requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); assertThat(response.getForwardedUrl()).isEqualTo("/hello"); @@ -205,17 +179,10 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setUseForward(true); ep.setForceHttps(true); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/bigWebApp/some_path"); - request.setServletPath("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get("http://127.0.0.1").requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com/bigWebApp/some_path"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1/bigWebApp/some_path"); } // SEC-1498 diff --git a/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java b/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java index 4dbf8acccc..3362b108e9 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java @@ -28,6 +28,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link RequestMatcherDelegatingAuthenticationManagerResolverTests} @@ -48,8 +49,7 @@ public class RequestMatcherDelegatingAuthenticationManagerResolverTests { .add(PathPatternRequestMatcher.withDefaults().matcher("/two/**"), this.two) .build(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/one/location"); - request.setServletPath("/one/location"); + MockHttpServletRequest request = get("/one/location").build(); assertThat(resolver.resolve(request)).isEqualTo(this.one); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java index 04f372d71d..d43144f47a 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java @@ -39,6 +39,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests {@link UsernamePasswordAuthenticationFilter}. @@ -128,10 +129,10 @@ public class UsernamePasswordAuthenticationFilterTests { @Test public void testSecurityContextHolderStrategyUsed() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login"); - request.setServletPath("/login"); - request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod"); - request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala"); + MockHttpServletRequest request = post("/login") + .param(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod") + .param(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala") + .build(); UsernamePasswordAuthenticationFilter filter = new UsernamePasswordAuthenticationFilter(); filter.setAuthenticationManager(createAuthenticationManager()); SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); diff --git a/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java index 6039fd27a8..e489b5914e 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java @@ -24,6 +24,8 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.firewall.DefaultHttpFirewall; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * @author Luke Taylor @@ -39,22 +41,20 @@ public class LogoutHandlerTests { @Test public void testRequiresLogoutUrlWorksWithPathParams() { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/context/logout;someparam=blah"); + MockHttpServletRequest request = post().requestUri("/context", "/logout;someparam=blah", null) + .queryString("otherparam=blah") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setContextPath("/context"); - request.setServletPath("/logout;someparam=blah"); - request.setQueryString("otherparam=blah"); DefaultHttpFirewall fw = new DefaultHttpFirewall(); assertThat(this.filter.requiresLogout(fw.getFirewalledRequest(request), response)).isTrue(); } @Test public void testRequiresLogoutUrlWorksWithQueryParams() { - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/context/logout"); - request.setContextPath("/context"); + MockHttpServletRequest request = get().requestUri("/context", "/logout", null) + .queryString("otherparam=blah") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/logout"); - request.setQueryString("param=blah"); assertThat(this.filter.requiresLogout(request, response)).isTrue(); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java index 4c7085ab70..0f32977543 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java @@ -38,6 +38,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link GenerateOneTimeTokenWebFilter} @@ -55,7 +56,7 @@ public class GenerateOneTimeTokenFilterTests { private static final String USERNAME = "user"; - private final MockHttpServletRequest request = new MockHttpServletRequest(); + private MockHttpServletRequest request; private final MockHttpServletResponse response = new MockHttpServletResponse(); @@ -63,9 +64,7 @@ public class GenerateOneTimeTokenFilterTests { @BeforeEach void setup() { - this.request.setMethod("POST"); - this.request.setServletPath("/ott/generate"); - this.request.setRequestURI("/ott/generate"); + this.request = post("/ott/generate").build(); } @Test @@ -87,6 +86,7 @@ public class GenerateOneTimeTokenFilterTests { void filterWhenUsernameFormParamIsEmptyThenNull() throws ServletException, IOException { given(this.oneTimeTokenService.generate(ArgumentMatchers.any(GenerateOneTimeTokenRequest.class))) .willReturn((new DefaultOneTimeToken(TOKEN, USERNAME, Instant.now()))); + GenerateOneTimeTokenFilter filter = new GenerateOneTimeTokenFilter(this.oneTimeTokenService, this.successHandler); diff --git a/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java index ad38fa6f7c..18f9af5b6f 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java @@ -27,6 +27,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link DefaultOneTimeTokenSubmitPageGeneratingFilter} @@ -37,7 +38,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { DefaultOneTimeTokenSubmitPageGeneratingFilter filter = new DefaultOneTimeTokenSubmitPageGeneratingFilter(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/login/ott"); + MockHttpServletRequest request; MockHttpServletResponse response = new MockHttpServletResponse(); @@ -45,9 +46,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { @BeforeEach void setup() { - this.request.setMethod("GET"); - this.request.setServletPath("/login/ott"); - this.request.setRequestURI("/login/ott"); + this.request = get("/login/ott").build(); } @Test @@ -80,10 +79,9 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { @Test void setContextThenGenerates() throws Exception { - this.request.setContextPath("/context"); - this.request.setRequestURI("/context/login/ott"); + MockHttpServletRequest request = get().requestUri("/context", "/login/ott", null).build(); this.filter.setLoginProcessingUrl("/login/another"); - this.filter.doFilterInternal(this.request, this.response, this.filterChain); + this.filter.doFilterInternal(request, this.response, this.filterChain); String response = this.response.getContentAsString(); assertThat(response).contains("
"); } @@ -101,7 +99,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { void filterThenRenders() throws Exception { this.request.setParameter("token", "this<>!@#\""); this.filter.setLoginProcessingUrl("/login/another"); - this.filter.setResolveHiddenInputs((request) -> Map.of("_csrf", "csrf-token-value")); + this.filter.setResolveHiddenInputs((r) -> Map.of("_csrf", "csrf-token-value")); this.filter.doFilterInternal(this.request, this.response, this.filterChain); String response = this.response.getContentAsString(); assertThat(response).isEqualTo( diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java index 389756b41b..a05f53392c 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java @@ -61,6 +61,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link BasicAuthenticationFilter}. @@ -94,8 +95,7 @@ public class BasicAuthenticationFilterTests { @Test public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/some_file.html"); + MockHttpServletRequest request = get("/some_file.html").build(); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, response, chain); @@ -113,9 +113,8 @@ public class BasicAuthenticationFilterTests { @Test public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception { String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -127,9 +126,8 @@ public class BasicAuthenticationFilterTests { @Test public void invalidBase64IsIgnored() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic NOT_VALID_BASE64"); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -143,9 +141,8 @@ public class BasicAuthenticationFilterTests { @Test public void testNormalOperation() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); FilterChain chain = mock(FilterChain.class); @@ -172,9 +169,8 @@ public class BasicAuthenticationFilterTests { @Test public void doFilterWhenSchemeLowercaseThenCaseInsensitveMatchWorks() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); FilterChain chain = mock(FilterChain.class); @@ -187,9 +183,8 @@ public class BasicAuthenticationFilterTests { @Test public void doFilterWhenSchemeMixedCaseThenCaseInsensitiveMatchWorks() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "BaSiC " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, new MockHttpServletResponse(), chain); @@ -200,9 +195,8 @@ public class BasicAuthenticationFilterTests { @Test public void testOtherAuthorizationSchemeIsIgnored() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME"); - request.setServletPath("/some_file.html"); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, new MockHttpServletResponse(), chain); verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); @@ -222,9 +216,8 @@ public class BasicAuthenticationFilterTests { @Test public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); final MockHttpServletResponse response1 = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, response1, chain); @@ -240,7 +233,6 @@ public class BasicAuthenticationFilterTests { chain = mock(FilterChain.class); this.filter.doFilter(request, response2, chain); verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); - request.setServletPath("/some_file.html"); // Test - the filter chain will not be invoked, as we get a 401 forbidden response MockHttpServletResponse response = response2; assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -250,9 +242,8 @@ public class BasicAuthenticationFilterTests { @Test public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception { String token = "rod:WRONG_PASSWORD"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); this.filter = new BasicAuthenticationFilter(this.manager); assertThat(this.filter.isIgnoreFailure()).isTrue(); @@ -266,9 +257,8 @@ public class BasicAuthenticationFilterTests { @Test public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception { String token = "rod:WRONG_PASSWORD"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); assertThat(this.filter.isIgnoreFailure()).isFalse(); final MockHttpServletResponse response = new MockHttpServletResponse(); @@ -284,9 +274,8 @@ public class BasicAuthenticationFilterTests { @Test public void skippedOnErrorDispatch() throws Exception { String token = "bad:credentials"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -307,10 +296,9 @@ public class BasicAuthenticationFilterTests { given(this.manager.authenticate(not(eq(rodRequest)))).willThrow(new BadCredentialsException("")); this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); String token = "rod:äöü"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8))); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -336,10 +324,9 @@ public class BasicAuthenticationFilterTests { this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter.setCredentialsCharset("ISO-8859-1"); String token = "rod:äöü"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.ISO_8859_1))); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -367,10 +354,9 @@ public class BasicAuthenticationFilterTests { this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter.setCredentialsCharset("ISO-8859-1"); String token = "rod:äöü"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8))); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -383,9 +369,8 @@ public class BasicAuthenticationFilterTests { @Test public void requestWhenEmptyBasicAuthorizationHeaderTokenThenUnauthorized() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic "); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -401,9 +386,8 @@ public class BasicAuthenticationFilterTests { SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); this.filter.setSecurityContextRepository(securityContextRepository); String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -496,9 +480,8 @@ public class BasicAuthenticationFilterTests { public void doFilterWhenCustomAuthenticationConverterThatIgnoresRequestThenIgnores() throws Exception { this.filter.setAuthenticationConverter(new TestAuthenticationConverter()); String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/ignored").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/ignored"); FilterChain filterChain = mock(FilterChain.class); MockHttpServletResponse response = new MockHttpServletResponse(); this.filter.doFilter(request, response, filterChain); @@ -513,9 +496,8 @@ public class BasicAuthenticationFilterTests { public void doFilterWhenCustomAuthenticationConverterRequestThenAuthenticate() throws Exception { this.filter.setAuthenticationConverter(new TestAuthenticationConverter()); String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/ok").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/ok"); FilterChain filterChain = mock(FilterChain.class); MockHttpServletResponse response = new MockHttpServletResponse(); this.filter.doFilter(request, response, filterChain); diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java index 230c554fcd..c3375abf29 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java @@ -53,6 +53,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link DigestAuthenticationFilter}. @@ -131,8 +132,7 @@ public class DigestAuthenticationFilterTests { this.filter = new DigestAuthenticationFilter(); this.filter.setUserDetailsService(uds); this.filter.setAuthenticationEntryPoint(ep); - this.request = new MockHttpServletRequest("GET", REQUEST_URI); - this.request.setServletPath(REQUEST_URI); + this.request = get(REQUEST_URI).build(); } @Test diff --git a/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java b/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java index 683db01360..1088f993e5 100644 --- a/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java @@ -41,6 +41,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -120,10 +121,7 @@ public class DebugFilterTests { @Test public void doFilterLogsProperly() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setMethod("GET"); - request.setServletPath("/path"); - request.setPathInfo("/"); + MockHttpServletRequest request = get().requestUri(null, "/path", "/").build(); request.addHeader("A", "A Value"); request.addHeader("A", "Another Value"); request.addHeader("B", "B Value"); diff --git a/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java index c7ce4d72f3..00f656816b 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Luke Taylor @@ -34,8 +35,7 @@ public class DefaultHttpFirewallTests { public void unnormalizedPathsAreRejected() { DefaultHttpFirewall fw = new DefaultHttpFirewall(); for (String path : this.unnormalizedPaths) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath(path); + MockHttpServletRequest request = get().requestUri(path).build(); assertThatExceptionOfType(RequestRejectedException.class) .isThrownBy(() -> fw.getFirewalledRequest(request)); request.setPathInfo(path); diff --git a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java index f5406f0094..e7a46bfe54 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java @@ -27,6 +27,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -112,8 +113,7 @@ public class StrictHttpFirewallTests { @Test public void getFirewalledRequestWhenServletPathNotNormalizedThenThrowsRequestRejectedException() { for (String path : this.unnormalizedPaths) { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath(path); + this.request = get().requestUri(path).build(); assertThatExceptionOfType(RequestRejectedException.class) .isThrownBy(() -> this.firewall.getFirewalledRequest(this.request)); } diff --git a/web/src/test/java/org/springframework/security/web/servlet/TestMockHttpServletRequests.java b/web/src/test/java/org/springframework/security/web/servlet/TestMockHttpServletRequests.java new file mode 100644 index 0000000000..017397640c --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/servlet/TestMockHttpServletRequests.java @@ -0,0 +1,169 @@ +/* + * Copyright 2004-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.servlet; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.util.StringUtils; + +public final class TestMockHttpServletRequests { + + private TestMockHttpServletRequests() { + + } + + public static Builder get() { + return new Builder(HttpMethod.GET); + } + + public static Builder get(String url) { + return get().applyUrl(url); + } + + public static Builder post() { + return new Builder(HttpMethod.POST); + } + + public static Builder post(String url) { + return post().applyUrl(url); + } + + public static Builder request(String method) { + return new Builder(HttpMethod.valueOf(method)); + } + + public static final class Builder { + + private static final Pattern URL = Pattern.compile("((?https?)://)?" + + "((?[^:/]+)(:(?\\d+))?)?" + "(?[^?]+)?" + "(\\?(?.*))?"); + + private final HttpMethod method; + + private String requestUri; + + private final Map parameters = new LinkedHashMap<>(); + + private String scheme = MockHttpServletRequest.DEFAULT_SCHEME; + + private int port = MockHttpServletRequest.DEFAULT_SERVER_PORT; + + private String hostname = MockHttpServletRequest.DEFAULT_SERVER_NAME; + + private String contextPath; + + private String servletPath; + + private String pathInfo; + + private String queryString; + + private Builder(HttpMethod method) { + this.method = method; + } + + private Builder applyUrl(String url) { + Matcher matcher = URL.matcher(url); + if (matcher.matches()) { + applyElement(this::scheme, matcher.group("scheme")); + applyElement(this::port, matcher.group("port")); + applyElement(this::serverName, matcher.group("hostname")); + applyElement(this::requestUri, matcher.group("path")); + applyElement(this::queryString, matcher.group("query")); + } + return this; + } + + private void applyElement(Consumer apply, T value) { + if (value != null) { + apply.accept(value); + } + } + + public Builder requestUri(String contextPath, String servletPath, String pathInfo) { + this.contextPath = contextPath; + this.servletPath = servletPath; + this.pathInfo = pathInfo; + this.requestUri = Stream.of(contextPath, servletPath, pathInfo) + .filter(StringUtils::hasText) + .collect(Collectors.joining()); + return this; + } + + public Builder requestUri(String requestUri) { + return requestUri(null, requestUri, null); + } + + public Builder param(String name, String value) { + this.parameters.put(name, value); + return this; + } + + private Builder port(String port) { + if (port != null) { + this.port = Integer.parseInt(port); + } + return this; + } + + public Builder port(int port) { + this.port = port; + return this; + } + + public Builder queryString(String queryString) { + this.queryString = queryString; + return this; + } + + public Builder scheme(String scheme) { + this.scheme = scheme; + return this; + } + + public Builder serverName(String serverName) { + this.hostname = serverName; + return this; + } + + public MockHttpServletRequest build() { + MockHttpServletRequest request = new MockHttpServletRequest(); + applyElement(request::setContextPath, this.contextPath); + applyElement(request::setContextPath, this.contextPath); + applyElement(request::setMethod, this.method.name()); + applyElement(request::setParameters, this.parameters); + applyElement(request::setPathInfo, this.pathInfo); + applyElement(request::setServletPath, this.servletPath); + applyElement(request::setScheme, this.scheme); + applyElement(request::setServerPort, this.port); + applyElement(request::setServerName, this.hostname); + applyElement(request::setQueryString, this.queryString); + applyElement(request::setRequestURI, this.requestUri); + request.setSecure("https".equals(this.scheme)); + return request; + } + + } + +} diff --git a/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java b/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java index 8263b776f0..9671303750 100644 --- a/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java +++ b/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java @@ -28,6 +28,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.given; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; import static org.springframework.security.web.util.matcher.RegexRequestMatcher.regexMatcher; /** @@ -50,8 +51,7 @@ public class RegexRequestMatcherTests { @Test public void matchesIfHttpMethodAndPathMatch() { RegexRequestMatcher matcher = new RegexRequestMatcher(".*", "GET"); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/anything"); - request.setServletPath("/anything"); + MockHttpServletRequest request = get("/anything").build(); assertThat(matcher.matches(request)).isTrue(); }