From 98686a51390d3450c4d8c83573f59510643cd393 Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Wed, 2 Jul 2025 18:16:41 -0600 Subject: [PATCH] Standardize Mock Request Paths Closes gh-17449 --- cas/spring-security-cas.gradle | 1 + .../cas/web/CasAuthenticationFilterTests.java | 36 +++--- .../config/FilterChainProxyConfigTests.java | 6 +- .../configurers/AuthorizeRequestsTests.java | 7 +- .../configurers/HttpSecurityLogoutTests.java | 11 +- .../HttpSecurityRequestMatchersTests.java | 59 +++++----- ...ttpSecuritySecurityMatchersNoMvcTests.java | 8 +- ...ionManagementConfigurerServlet31Tests.java | 23 +--- .../client/OAuth2LoginConfigurerTests.java | 52 ++++----- ...reshedEventListenerConfigurationTests.java | 4 +- .../OidcUserRefreshedEventListenerTests.java | 4 +- .../saml2/Saml2LoginConfigurerTests.java | 5 +- .../saml2/Saml2LogoutConfigurerTests.java | 4 +- ...tadataSourceBeanDefinitionParserTests.java | 4 +- .../Saml2LogoutBeanDefinitionParserTests.java | 3 +- ...SessionManagementConfigServlet31Tests.java | 34 ++---- .../CustomHttpSecurityConfigurerTests.java | 10 +- ...egistrationsBeanDefinitionParserTests.java | 5 +- .../web/AuthorizeRequestsDslTests.kt | 31 +---- .../annotation/web/RequiresChannelDslTests.kt | 10 +- etc/checkstyle/checkstyle.xml | 1 + .../spring-security-itest-context.gradle | 3 +- ...amespaceWithMultipleInterceptorsTests.java | 10 +- .../spring-security-oauth2-client.gradle | 1 + ...uth2AuthorizationRequestResolverTests.java | 106 +++++------------- ...uth2AuthorizationCodeGrantFilterTests.java | 10 +- ...thorizationRequestRedirectFilterTests.java | 37 +++--- .../OAuth2LoginAuthenticationFilterTests.java | 95 +++++++--------- ...ing-security-saml2-service-provider.gradle | 1 + ...aml4AuthenticationTokenConverterTests.java | 9 +- ...SamlAuthenticationTokenConverterTests.java | 9 +- ...ml4AuthenticationRequestResolverTests.java | 5 +- ...questValidatorParametersResolverTests.java | 9 +- ...questValidatorParametersResolverTests.java | 9 +- ...aml5AuthenticationTokenConverterTests.java | 9 +- ...ml5AuthenticationRequestResolverTests.java | 5 +- ...questValidatorParametersResolverTests.java | 9 +- ...tMatcherMetadataResponseResolverTests.java | 5 +- .../logout/Saml2LogoutRequestFilterTests.java | 34 +++--- .../Saml2LogoutResponseFilterTests.java | 27 ++--- ...rtyInitiatedLogoutSuccessHandlerTests.java | 7 +- .../security/web/FilterChainProxyTests.java | 4 +- .../security/web/FilterInvocationTests.java | 36 ++---- .../RequestMatcherRedirectFilterTests.java | 8 +- .../ExceptionTranslationFilterTests.java | 44 ++------ .../channel/ChannelProcessingFilterTests.java | 10 +- .../InsecureChannelProcessorTests.java | 22 ++-- .../channel/SecureChannelProcessorTests.java | 22 ++-- ...leEvaluationContextPostProcessorTests.java | 4 +- ...InvocationSecurityMetadataSourceTests.java | 22 ++-- .../FilterSecurityInterceptorTests.java | 4 +- ...ctAuthenticationProcessingFilterTests.java | 36 +++--- ...LoginUrlAuthenticationEntryPointTests.java | 61 +++------- ...ingAuthenticationManagerResolverTests.java | 4 +- ...namePasswordAuthenticationFilterTests.java | 9 +- .../logout/LogoutHandlerTests.java | 16 +-- .../ott/GenerateOneTimeTokenFilterTests.java | 8 +- ...eTokenSubmitPageGeneratingFilterTests.java | 14 +-- .../www/BasicAuthenticationFilterTests.java | 56 ++++----- .../www/DigestAuthenticationFilterTests.java | 4 +- .../security/web/debug/DebugFilterTests.java | 6 +- .../firewall/DefaultHttpFirewallTests.java | 4 +- .../web/firewall/StrictHttpFirewallTests.java | 4 +- .../matcher/RegexRequestMatcherTests.java | 4 +- 64 files changed, 399 insertions(+), 721 deletions(-) diff --git a/cas/spring-security-cas.gradle b/cas/spring-security-cas.gradle index cc5c13f604..fd4a614fa5 100644 --- a/cas/spring-security-cas.gradle +++ b/cas/spring-security-cas.gradle @@ -14,6 +14,7 @@ dependencies { provided 'jakarta.servlet:jakarta.servlet-api' + testImplementation project(path : ':spring-security-web', configuration : 'tests') testImplementation "org.assertj:assertj-core" testImplementation "org.junit.jupiter:junit-jupiter-api" testImplementation "org.junit.jupiter:junit-jupiter-params" diff --git a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java index 296043527e..423c99cfe5 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java @@ -55,6 +55,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests {@link CasAuthenticationFilter}. @@ -79,9 +81,7 @@ public class CasAuthenticationFilterTests { @Test public void testNormalOperation() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login/cas"); - request.setServletPath("/login/cas"); - request.addParameter("ticket", "ST-0-ER94xMJmn6pha35CQRoZ"); + MockHttpServletRequest request = post("/login/cas").param("ticket", "ST-0-ER94xMJmn6pha35CQRoZ").build(); CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setAuthenticationManager((a) -> a); assertThat(filter.requiresAuthentication(request, new MockHttpServletResponse())).isTrue(); @@ -104,24 +104,22 @@ public class CasAuthenticationFilterTests { String url = "/login/cas"; CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); - MockHttpServletRequest request = new MockHttpServletRequest("POST", url); + MockHttpServletRequest request = post(url).build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); } @Test public void testRequiresAuthenticationProxyRequest() { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyReceptorUrl(request.getServletPath()); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request.setServletPath("/other"); + request = get("/other").build(); assertThat(filter.requiresAuthentication(request, response)).isFalse(); } @@ -133,12 +131,10 @@ public class CasAuthenticationFilterTests { CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); filter.setServiceProperties(properties); - MockHttpServletRequest request = new MockHttpServletRequest("POST", url); + MockHttpServletRequest request = post(url).build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request = new MockHttpServletRequest("POST", "/other"); - request.setServletPath("/other"); + request = post("/other").build(); assertThat(filter.requiresAuthentication(request, response)).isFalse(); request.setParameter(properties.getArtifactParameter(), "value"); assertThat(filter.requiresAuthentication(request, response)).isTrue(); @@ -156,9 +152,8 @@ public class CasAuthenticationFilterTests { @Test public void testAuthenticateProxyUrl() throws Exception { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyReceptorUrl(request.getServletPath()); assertThat(filter.attemptAuthentication(request, response)).isNull(); @@ -172,9 +167,7 @@ public class CasAuthenticationFilterTests { given(manager.authenticate(any(Authentication.class))).willReturn(authentication); ServiceProperties serviceProperties = new ServiceProperties(); serviceProperties.setAuthenticateAllArtifacts(true); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/authenticate"); - request.setParameter("ticket", "ST-1-123"); - request.setServletPath("/authenticate"); + MockHttpServletRequest request = post("/authenticate").param("ticket", "ST-1-123").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); CasAuthenticationFilter filter = new CasAuthenticationFilter(); @@ -200,10 +193,9 @@ public class CasAuthenticationFilterTests { @Test public void testChainNotInvokedForProxyReceptor() throws Exception { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - request.setServletPath("/pgtCallback"); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyReceptorUrl(request.getServletPath()); filter.doFilter(request, response, chain); @@ -271,16 +263,14 @@ public class CasAuthenticationFilterTests { @Test public void requiresAuthenticationWhenProxyRequestMatcherThenMatches() { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/pgtCallback"); + MockHttpServletRequest request = get("/pgtCallback").build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyReceptorMatcher(PathPatternRequestMatcher.withDefaults().matcher(request.getServletPath())); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request.setRequestURI("/other"); - request.setServletPath("/other"); + request = get("/other").build(); assertThat(filter.requiresAuthentication(request, response)).isFalse(); } diff --git a/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java b/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java index e2f81e3e17..4794bce61a 100644 --- a/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java @@ -44,6 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link FilterChainProxy}. @@ -143,13 +144,12 @@ public class FilterChainProxyConfigTests { } private void doNormalOperation(FilterChainProxy filterChainProxy) throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setServletPath("/foo/secure/super/somefile.html"); + MockHttpServletRequest request = get("/foo/secure/super/somefile.html").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); filterChainProxy.doFilter(request, response, chain); verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - request.setServletPath("/a/path/which/doesnt/match/any/filter.html"); + request = get("/a/path/which/doesnt/match/any/filter.html").build(); chain = mock(FilterChain.class); filterChainProxy.doFilter(request, response, chain); verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java index 992dc0de1f..4ecbd91d48 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java @@ -77,7 +77,6 @@ public class AuthorizeRequestsTests { public void setup() { this.servletContext = spy(MockServletContext.mvc()); this.request = new MockHttpServletRequest(this.servletContext, "GET", ""); - this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -111,10 +110,12 @@ public class AuthorizeRequestsTests { public void antMatchersPathVariables() throws Exception { loadConfig(AntPatchersPathVariables.class); this.request.setServletPath("/user/user"); + this.request.setRequestURI("/user/user"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); this.setup(); this.request.setServletPath("/user/deny"); + this.request.setRequestURI("/user/deny"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } @@ -123,10 +124,12 @@ public class AuthorizeRequestsTests { @Test public void antMatchersPathVariablesCaseInsensitive() throws Exception { loadConfig(AntPatchersPathVariables.class); + this.request.setRequestURI("/USER/user"); this.request.setServletPath("/USER/user"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); this.setup(); + this.request.setRequestURI("/USER/deny"); this.request.setServletPath("/USER/deny"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); @@ -137,10 +140,12 @@ public class AuthorizeRequestsTests { public void antMatchersPathVariablesCaseInsensitiveCamelCaseVariables() throws Exception { loadConfig(AntMatchersPathVariablesCamelCaseVariables.class); this.request.setServletPath("/USER/user"); + this.request.setRequestURI("/USER/user"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); this.setup(); this.request.setServletPath("/USER/deny"); + this.request.setRequestURI("/USER/deny"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java index b82c2a57a3..98340ec271 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java @@ -39,6 +39,7 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * @author Rob Winch @@ -48,8 +49,6 @@ public class HttpSecurityLogoutTests { AnnotationConfigWebApplicationContext context; - MockHttpServletRequest request; - MockHttpServletResponse response; MockFilterChain chain; @@ -59,7 +58,6 @@ public class HttpSecurityLogoutTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -77,11 +75,10 @@ public class HttpSecurityLogoutTests { loadConfig(ClearAuthenticationFalseConfig.class); SecurityContext currentContext = SecurityContextHolder.createEmptyContext(); currentContext.setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); - this.request.getSession() + MockHttpServletRequest request = post("/logout").build(); + request.getSession() .setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, currentContext); - this.request.setMethod("POST"); - this.request.setServletPath("/logout"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(currentContext.getAuthentication()).isNotNull(); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java index 85d7838988..ef85cab20e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java @@ -45,6 +45,7 @@ import org.springframework.web.servlet.handler.HandlerMappingIntrospector; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -54,8 +55,6 @@ public class HttpSecurityRequestMatchersTests { AnnotationConfigWebApplicationContext context; - MockHttpServletRequest request; - MockHttpServletResponse response; MockFilterChain chain; @@ -65,8 +64,6 @@ public class HttpSecurityRequestMatchersTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -87,70 +84,64 @@ public class HttpSecurityRequestMatchersTests { @Test public void requestMatchersMvcMatcherServletPath() throws Exception { loadConfig(RequestMatchersMvcMatcherServeltPathConfig.class); - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get().requestUri(null, "/spring", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath(""); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); setup(); - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "/other", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void requestMatcherWhensMvcMatcherServletPathInLambdaThenPathIsSecured() throws Exception { loadConfig(RequestMatchersMvcMatcherServletPathInLambdaConfig.class); - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get().requestUri(null, "/spring", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath(""); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); setup(); - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get().requestUri(null, "/other", "/path").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void requestMatcherWhenMultiMvcMatcherInLambdaThenAllPathsAreDenied() throws Exception { loadConfig(MultiMvcMatcherInLambdaConfig.class); - this.request.setRequestURI("/test-1"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get("/test-1").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-2"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-2").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-3"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-3").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @Test public void requestMatcherWhenMultiMvcMatcherThenAllPathsAreDenied() throws Exception { loadConfig(MultiMvcMatcherConfig.class); - this.request.setRequestURI("/test-1"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + MockHttpServletRequest request = get("/test-1").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-2"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-2").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/test-3"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + request = get("/test-3").build(); + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java index c6afdf3572..0b3cba058c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersNoMvcTests.java @@ -67,7 +67,7 @@ public class HttpSecuritySecurityMatchersNoMvcTests { @BeforeEach public void setup() throws Exception { - this.request = new MockHttpServletRequest("GET", ""); + this.request = new MockHttpServletRequest(); this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); @@ -83,15 +83,15 @@ public class HttpSecuritySecurityMatchersNoMvcTests { @Test public void securityMatcherWhenNoMvcThenAntMatcher() throws Exception { loadConfig(SecurityMatcherNoMvcConfig.class); - this.request.setServletPath("/path"); + this.request.setRequestURI("/path"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath("/path.html"); + this.request.setRequestURI("/path.html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); setup(); - this.request.setServletPath("/path/"); + this.request.setRequestURI("/path/"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); List requestMatchers = this.springSecurityFilterChain.getFilterChains() .stream() diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java index cd32887961..98998e31e8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java @@ -30,14 +30,10 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; -import org.springframework.security.web.context.HttpRequestResponseHolder; -import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.DeferredCsrfToken; @@ -46,14 +42,13 @@ import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.config.Customizer.withDefaults; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * @author Rob Winch */ public class SessionManagementConfigurerServlet31Tests { - MockHttpServletRequest request; - MockHttpServletResponse response; MockFilterChain chain; @@ -64,7 +59,6 @@ public class SessionManagementConfigurerServlet31Tests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -78,13 +72,9 @@ public class SessionManagementConfigurerServlet31Tests { @Test public void changeSessionIdThenPreserveParameters() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletRequest request = post("/login").param("username", "user").param("password", "password").build(); String id = request.getSession().getId(); request.getSession(); - request.setServletPath("/login"); - request.setMethod("POST"); - request.setParameter("username", "user"); - request.setParameter("password", "password"); HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, this.response); @@ -106,15 +96,6 @@ public class SessionManagementConfigurerServlet31Tests { this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } - private void login(Authentication auth) { - HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response); - repo.loadContext(requestResponseHolder); - SecurityContextImpl securityContextImpl = new SecurityContextImpl(); - securityContextImpl.setAuthentication(auth); - repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse()); - } - @Configuration @EnableWebSecurity static class SessionManagementDefaultSessionFixationServlet31Config { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index cc1a30a381..d24fc4f723 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -107,6 +107,7 @@ import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; import org.springframework.test.util.ReflectionTestUtils; @@ -127,6 +128,7 @@ import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -185,8 +187,7 @@ public class OAuth2LoginConfigurerTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/login/oauth2/code/google"); + this.request = TestMockHttpServletRequests.get("/login/oauth2/code/google").build(); this.response = new MockHttpServletResponse(); this.filterChain = new MockFilterChain(); } @@ -347,7 +348,7 @@ public class OAuth2LoginConfigurerTests { loadConfig(OAuth2LoginConfigLoginProcessingUrl.class); // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.request.setServletPath("/login/oauth2/google"); + this.request.setRequestURI("/login/oauth2/google"); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); @@ -381,8 +382,7 @@ public class OAuth2LoginConfigurerTests { // @formatter:on given(resolver.resolve(any())).willReturn(result); String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = TestMockHttpServletRequests.get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).isEqualTo( "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); @@ -394,8 +394,7 @@ public class OAuth2LoginConfigurerTests { // @formatter:off // @formatter:on String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = TestMockHttpServletRequests.get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).isEqualTo( "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); @@ -418,8 +417,7 @@ public class OAuth2LoginConfigurerTests { // @formatter:on given(resolver.resolve(any())).willReturn(result); String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = TestMockHttpServletRequests.get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).isEqualTo( "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); @@ -432,8 +430,7 @@ public class OAuth2LoginConfigurerTests { RedirectStrategy redirectStrategy = this.context .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class).redirectStrategy; String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); } @@ -445,8 +442,7 @@ public class OAuth2LoginConfigurerTests { RedirectStrategy redirectStrategy = this.context .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class).redirectStrategy; String requestUri = "/oauth2/authorization/google"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); } @@ -456,8 +452,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception { loadConfig(OAuth2LoginConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); } @@ -467,8 +462,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithOneClientConfiguredAndFormLoginThenRedirectDefaultLoginPage() throws Exception { loadConfig(OAuth2LoginConfigFormLogin.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); } @@ -479,8 +473,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginConfig.class); String requestUri = "/favicon.ico"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader(HttpHeaders.ACCEPT, new MediaType("image", "*").toString()); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); @@ -491,8 +484,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithMultipleClientsConfiguredThenRedirectDefaultLoginPage() throws Exception { loadConfig(OAuth2LoginConfigMultipleClients.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); } @@ -503,8 +495,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).doesNotMatch("http://localhost/oauth2/authorization/google"); @@ -515,8 +506,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginWithHttpBasicConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getStatus()).isEqualTo(401); @@ -527,8 +517,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginWithXHREntryPointConfig.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getStatus()).isEqualTo(401); @@ -540,8 +529,7 @@ public class OAuth2LoginConfigurerTests { throws Exception { loadConfig(OAuth2LoginConfigAuthorizationCodeClientAndOtherClients.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); } @@ -550,8 +538,7 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws Exception { loadConfig(OAuth2LoginConfigCustomLoginPage.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); } @@ -560,8 +547,7 @@ public class OAuth2LoginConfigurerTests { public void requestWhenOauth2LoginWithCustomLoginPageInLambdaThenRedirectCustomLoginPage() throws Exception { loadConfig(OAuth2LoginConfigCustomLoginPageInLambda.class); String requestUri = "/"; - this.request = new MockHttpServletRequest("GET", requestUri); - this.request.setServletPath(requestUri); + this.request = get(requestUri).build(); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java index df28e06209..30907b5e8d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerConfigurationTests.java @@ -89,6 +89,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OidcUserRefreshedEventListener} with {@link OAuth2LoginConfigurer}. @@ -147,8 +148,7 @@ public class OidcUserRefreshedEventListenerConfigurationTests { @BeforeEach public void setUp() { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/"); + this.request = get("/").build(); this.response = new MockHttpServletResponse(); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java index 6b8f82a8bd..84c7a3e42b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcUserRefreshedEventListenerTests.java @@ -42,6 +42,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OidcUserRefreshedEventListener}. @@ -64,8 +65,7 @@ public class OidcUserRefreshedEventListenerTests { this.eventListener = new OidcUserRefreshedEventListener(); this.eventListener.setSecurityContextRepository(this.securityContextRepository); - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/"); + this.request = get("/").build(); this.response = new MockHttpServletResponse(); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 0458ced521..84eafcf05d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -94,6 +94,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -190,8 +191,7 @@ public class Saml2LoginConfigurerTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("POST", ""); - this.request.setServletPath("/login/saml2/sso/test-rp"); + this.request = TestMockHttpServletRequests.post("/login/saml2/sso/test-rp").build(); this.response = new MockHttpServletResponse(); this.filterChain = new MockFilterChain(); } @@ -430,7 +430,6 @@ public class Saml2LoginConfigurerTests { private void performSaml2Login(String expected) throws IOException, ServletException { // setup authentication parameters this.request.setRequestURI("/login/saml2/sso/registration-id"); - this.request.setServletPath("/login/saml2/sso/registration-id"); this.request.setParameter("SAMLResponse", Base64.getEncoder().encodeToString("saml2-xml-response-object".getBytes())); // perform test diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java index b9d4deec08..da8912c6ba 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java @@ -76,6 +76,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.logout.LogoutFilter; import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -158,8 +159,7 @@ public class Saml2LogoutConfigurerTests { Collections.emptyMap()); principal.setRelyingPartyRegistrationId("registration-id"); this.user = new Saml2Authentication(principal, "response", AuthorityUtils.createAuthorityList("ROLE_USER")); - this.request = new MockHttpServletRequest("POST", ""); - this.request.setServletPath("/login/saml2/sso/test-rp"); + this.request = TestMockHttpServletRequests.post("/login/saml2/sso/test-rp").build(); this.response = new MockHttpServletResponse(); } diff --git a/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java index 21972bac10..2567471afd 100644 --- a/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java @@ -132,9 +132,7 @@ public class FilterSecurityMetadataSourceBeanDefinitionParserTests { } private FilterInvocation createFilterInvocation(String path, String method) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setRequestURI(path); - request.setMethod(method); + MockHttpServletRequest request = new MockHttpServletRequest(method, path); return new FilterInvocation(request, new MockHttpServletResponse(), new MockFilterChain()); } diff --git a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java index 4045c1b689..71da35955b 100644 --- a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java @@ -134,8 +134,7 @@ public class Saml2LogoutBeanDefinitionParserTests { principal.setRelyingPartyRegistrationId("registration-id"); this.saml2User = new Saml2Authentication(principal, "response", AuthorityUtils.createAuthorityList("ROLE_USER")); - this.request = new MockHttpServletRequest("POST", ""); - this.request.setServletPath("/login/saml2/sso/test-rp"); + this.request = new MockHttpServletRequest("POST", "/login/saml2/sso/test-rp"); this.response = new MockHttpServletResponse(); } diff --git a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java index 03c50c9db9..333c2db61f 100644 --- a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java @@ -26,10 +26,7 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.util.InMemoryXmlApplicationContext; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextImpl; -import org.springframework.security.web.context.HttpRequestResponseHolder; -import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; @@ -61,7 +58,7 @@ public class SessionManagementConfigServlet31Tests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); + this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); } @@ -75,12 +72,11 @@ public class SessionManagementConfigServlet31Tests { @Test public void changeSessionIdThenPreserveParameters() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletRequest request = TestMockHttpServletRequests.post("/login") + .param("username", "user") + .param("password", "password") + .build(); request.getSession(); - request.setServletPath("/login"); - request.setMethod("POST"); - request.setParameter("username", "user"); - request.setParameter("password", "password"); request.getSession().setAttribute("attribute1", "value1"); String id = request.getSession().getId(); // @formatter:off @@ -99,12 +95,11 @@ public class SessionManagementConfigServlet31Tests { @Test public void changeSessionId() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletRequest request = TestMockHttpServletRequests.post("/login") + .param("username", "user") + .param("password", "password") + .build(); request.getSession(); - request.setServletPath("/login"); - request.setMethod("POST"); - request.setParameter("username", "user"); - request.setParameter("password", "password"); String id = request.getSession().getId(); // @formatter:off loadContext("\n" @@ -124,13 +119,4 @@ public class SessionManagementConfigServlet31Tests { this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } - private void login(Authentication auth) { - HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response); - repo.loadContext(requestResponseHolder); - SecurityContextImpl securityContextImpl = new SecurityContextImpl(); - securityContextImpl.setAuthentication(auth); - repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse()); - } - } diff --git a/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java b/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java index c47bf68b6e..01d2f83c61 100644 --- a/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomHttpSecurityConfigurerTests.java @@ -60,7 +60,7 @@ public class CustomHttpSecurityConfigurerTests { @BeforeEach public void setup() { - this.request = new MockHttpServletRequest("GET", ""); + this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); this.request.setMethod("GET"); @@ -76,7 +76,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerPermitAll() throws Exception { loadContext(Config.class); - this.request.setPathInfo("/public/something"); + this.request.setRequestURI("/public/something"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @@ -84,7 +84,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerFormLogin() throws Exception { loadContext(Config.class); - this.request.setPathInfo("/requires-authentication"); + this.request.setRequestURI("/requires-authentication"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getRedirectedUrl()).endsWith("/custom"); } @@ -92,7 +92,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerCustomizeDisablesCsrf() throws Exception { loadContext(ConfigCustomize.class); - this.request.setPathInfo("/public/something"); + this.request.setRequestURI("/public/something"); this.request.setMethod("POST"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -101,7 +101,7 @@ public class CustomHttpSecurityConfigurerTests { @Test public void customConfiguerCustomizeFormLogin() throws Exception { loadContext(ConfigCustomize.class); - this.request.setPathInfo("/requires-authentication"); + this.request.setRequestURI("/requires-authentication"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(this.response.getRedirectedUrl()).endsWith("/other"); } diff --git a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java index c30593099a..2945ec8d17 100644 --- a/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java @@ -41,6 +41,7 @@ import org.springframework.security.saml2.provider.service.web.authentication.Op import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link RelyingPartyRegistrationsBeanDefinitionParser}. @@ -280,9 +281,7 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests { Converter relayStateResolver = this.spring.getContext().getBean(Converter.class); OpenSaml4AuthenticationRequestResolver authenticationRequestResolver = this.spring.getContext() .getBean(OpenSaml4AuthenticationRequestResolver.class); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/saml2/authenticate/one"); - request.setServletPath("/saml2/authenticate/one"); + MockHttpServletRequest request = get("/saml2/authenticate/one").build(); authenticationRequestResolver.resolve(request); verify(relayStateResolver).convert(request); } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt index 42eda2bbfc..ed70c24915 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/AuthorizeRequestsDslTests.kt @@ -44,8 +44,6 @@ import org.springframework.web.bind.annotation.PathVariable import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RestController import org.springframework.web.servlet.config.annotation.EnableWebMvc -import org.springframework.web.servlet.config.annotation.PathMatchConfigurer -import org.springframework.web.servlet.config.annotation.WebMvcConfigurer /** * Tests for [AuthorizeRequestsDsl] @@ -405,17 +403,11 @@ class AuthorizeRequestsDslTests { this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") - .with { request -> - request.servletPath = "/spring" - request - }) + .servletPath("/spring")) .andExpect(status().isForbidden) this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") - .with { request -> - request.servletPath = "/other" - request - }) + .servletPath("/other")) .andExpect(status().isOk) } @@ -514,28 +506,15 @@ class AuthorizeRequestsDslTests { this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") - .with { request -> - request.apply { - servletPath = "/spring" - } - }) + .servletPath("/spring")) .andExpect(status().isForbidden) this.mockMvc.perform(MockMvcRequestBuilders.put("/spring/path") - .with { request -> - request.apply { - servletPath = "/spring" - csrf() - } - }) + .servletPath("/spring")) .andExpect(status().isForbidden) this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") - .with { request -> - request.apply { - servletPath = "/other" - } - }) + .servletPath("/other")) .andExpect(status().isOk) } } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt index 28632c1d92..439e64b663 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/RequiresChannelDslTests.kt @@ -83,18 +83,12 @@ class RequiresChannelDslTests { this.spring.register(MvcMatcherServletPathConfig::class.java).autowire() this.mockMvc.perform(MockMvcRequestBuilders.get("/spring/path") - .with { request -> - request.servletPath = "/spring" - request - }) + .servletPath("/spring")) .andExpect(status().isFound) .andExpect(redirectedUrl("https://localhost/spring/path")) this.mockMvc.perform(MockMvcRequestBuilders.get("/other/path") - .with { request -> - request.servletPath = "/other" - request - }) + .servletPath("/other")) .andExpect(MockMvcResultMatchers.status().isOk) } diff --git a/etc/checkstyle/checkstyle.xml b/etc/checkstyle/checkstyle.xml index 9b4b616eb0..04453420b3 100644 --- a/etc/checkstyle/checkstyle.xml +++ b/etc/checkstyle/checkstyle.xml @@ -18,6 +18,7 @@ + diff --git a/itest/context/spring-security-itest-context.gradle b/itest/context/spring-security-itest-context.gradle index c278418f74..15c4b52dbd 100644 --- a/itest/context/spring-security-itest-context.gradle +++ b/itest/context/spring-security-itest-context.gradle @@ -9,7 +9,8 @@ dependencies { implementation 'org.springframework:spring-context' implementation 'org.springframework:spring-tx' - testImplementation project(':spring-security-web') + testImplementation project(path: ':spring-security-web') + testImplementation project(path: ':spring-security-web', configuration: 'tests') testImplementation 'jakarta.servlet:jakarta.servlet-api' testImplementation 'org.springframework:spring-web' testImplementation "org.assertj:assertj-core" diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java index aab6742b81..cfdeed656b 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java @@ -29,6 +29,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -43,9 +44,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests { @Test public void requestThatIsMatchedByDefaultInterceptorIsAllowed() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setMethod("GET"); - request.setServletPath("/somefile.html"); + MockHttpServletRequest request = TestMockHttpServletRequests.get("/somefile.html").build(); request.setSession(createAuthenticatedSession("ROLE_0", "ROLE_1", "ROLE_2")); MockHttpServletResponse response = new MockHttpServletResponse(); this.fcp.doFilter(request, response, new MockFilterChain()); @@ -54,10 +53,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests { @Test public void securedUrlAccessIsRejectedWithoutRequiredRole() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setMethod("GET"); - - request.setServletPath("/secure/somefile.html"); + MockHttpServletRequest request = TestMockHttpServletRequests.get("/secure/somefile.html").build(); request.setSession(createAuthenticatedSession("ROLE_0")); MockHttpServletResponse response = new MockHttpServletResponse(); this.fcp.doFilter(request, response, new MockFilterChain()); diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index 93bab342cf..11b6c91f0a 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -18,6 +18,7 @@ dependencies { testImplementation project(path: ':spring-security-oauth2-core', configuration: 'tests') testImplementation project(path: ':spring-security-oauth2-jose', configuration: 'tests') + testImplementation project(path: ':spring-security-web', configuration: 'tests') testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation 'io.micrometer:context-propagation' testImplementation 'io.projectreactor.netty:reactor-netty' diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index 1382a0368f..92a170394a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -44,6 +44,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.entry; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}. @@ -123,8 +125,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { @Test public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNull(); } @@ -133,7 +134,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { @Test public void resolveWhenNotAuthorizationRequestThenRequestBodyNotConsumed() throws IOException { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + MockHttpServletRequest request = post(requestUri).build(); request.setContent("foo".getBytes(StandardCharsets.UTF_8)); request.setCharacterEncoding(StandardCharsets.UTF_8.name()); HttpServletRequest spyRequest = Mockito.spy(request); @@ -151,8 +152,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() + "-invalid"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); // @formatter:off assertThatIllegalArgumentException() .isThrownBy(() -> this.resolver.resolve(request)) @@ -164,8 +164,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestWithValidClientThenResolves() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAuthorizationUri()) @@ -191,8 +190,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves() { ClientRegistration clientRegistration = this.registration2; String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId()); assertThat(authorizationRequest).isNotNull(); @@ -204,8 +202,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -216,9 +213,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServerPort(8080); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("localhost:8080" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -229,10 +224,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerPort(8081); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("https://localhost:8081" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -243,10 +235,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("http"); - request.setServerPort(80); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("http://localhost" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -257,10 +246,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerPort(443); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("https://localhost:443" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -271,10 +257,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestHasNoPortThenInvalidUrlException() { ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerPort(-1); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).port(-1).build(); assertThatExceptionOfType(InvalidUrlException.class).isThrownBy(() -> this.resolver.resolve(request)); } @@ -283,9 +266,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.setQueryString("foo=bar"); + MockHttpServletRequest request = get(requestUri + "?foo=bar").build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); assertThat(authorizationRequest.getRedirectUri()) @@ -296,11 +277,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("http"); - request.setServerName("localhost"); - request.setServerPort(80); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" @@ -312,11 +289,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerName("example.com"); - request.setServerPort(443); - request.setServletPath(requestUri); + MockHttpServletRequest request = get("https://example.com:443" + requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" @@ -328,8 +301,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirectUriIsAuthorize() { ClientRegistration clientRegistration = this.registration1; String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId()); assertThat(authorizationRequest.getAuthorizationRequestUri()) @@ -342,8 +314,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" @@ -355,9 +326,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.addParameter("action", "authorize"); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).param("action", "authorize").build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" @@ -369,9 +338,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() { ClientRegistration clientRegistration = this.registration2; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.addParameter("action", "login"); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).param("action", "login").build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" @@ -383,8 +350,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestWithValidPublicClientThenResolves() { ClientRegistration clientRegistration = this.publicClientRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAuthorizationUri()) @@ -420,15 +386,13 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertPkceApplied(authorizationRequest, clientRegistration); clientRegistration = this.registration2; requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + request = get(requestUri).build(); authorizationRequest = this.resolver.resolve(request); assertPkceApplied(authorizationRequest, clientRegistration); } @@ -447,15 +411,13 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { ClientRegistration clientRegistration = this.registration1; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertPkceApplied(authorizationRequest, clientRegistration); clientRegistration = this.registration2; requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + request = get(requestUri).build(); authorizationRequest = this.resolver.resolve(request); assertPkceNotApplied(authorizationRequest, clientRegistration); } @@ -491,8 +453,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAuthorizationUri()) @@ -524,8 +485,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); this.resolver.setAuthorizationRequestCustomizer( (builder) -> builder.additionalParameters((params) -> params.remove(OidcParameterNames.NONCE)) .attributes((attrs) -> attrs.remove(OidcParameterNames.NONCE))); @@ -543,8 +503,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.authorizationRequestUri((uriBuilder) -> { uriBuilder.queryParam("param1", "value1"); return uriBuilder.build(); @@ -561,8 +520,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() { ClientRegistration clientRegistration = this.oidcRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); this.resolver.setAuthorizationRequestCustomizer((builder) -> builder.parameters((params) -> { params.put("appid", params.get("client_id")); params.remove("client_id"); @@ -579,8 +537,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { OAuth2AuthorizationRequestResolver resolver = new DefaultOAuth2AuthorizationRequestResolver( this.clientRegistrationRepository); String requestUri = this.authorizationRequestBaseUri + "/" + this.registration2.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()) .isEqualTo("http://localhost/login/oauth2/code/" + this.registration2.getRegistrationId()); @@ -590,8 +547,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() { ClientRegistration clientRegistration = this.pkceClientRegistration; String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAdditionalParameters().containsKey(PkceParameterNames.CODE_CHALLENGE_METHOD)) .isTrue(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index a6f1f9bc96..440b3a6383 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -72,6 +72,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OAuth2AuthorizationCodeGrantFilter}. @@ -154,8 +155,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @Test public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); // NOTE: A valid Authorization Response contains either a 'code' or 'error' // parameter. MockHttpServletResponse response = new MockHttpServletResponse(); @@ -328,8 +328,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @Test public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception { String requestUri = "/saved-request"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); RequestCache requestCache = new HttpSessionRequestCache(); requestCache.saveRequest(request, response); @@ -430,8 +429,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map parameters) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); if (!CollectionUtils.isEmpty(parameters)) { parameters.forEach(request::addParameter); request.setQueryString(parameters.entrySet() diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 59676b1461..bc51ba9e7b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -55,6 +55,7 @@ import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}. @@ -127,8 +128,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { @Test public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -139,8 +139,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalServerError() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId() + "-invalid"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -154,8 +153,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId() + "-invalid"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> { @@ -178,8 +176,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -193,8 +190,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenAuthorizationRequestOAuth2LoginThenAuthorizationRequestSaved() throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration2.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); AuthorizationRequestRepository authorizationRequestRepository = mock( @@ -212,8 +208,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, authorizationRequestBaseUri); String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -227,8 +222,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectForAuthorization() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) @@ -245,8 +239,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) @@ -266,8 +259,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); request.addParameter("idp", "https://other.provider.com"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -295,8 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); String loginHintParamName = "login_hint"; request.addParameter(loginHintParamName, "user@provider.com"); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -335,8 +326,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { throws Exception { String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> { @@ -363,8 +353,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); willAnswer((invocation) -> assertThat((invocation.getArgument(1)).isCommitted()).isFalse()) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 3dee3c0cc0..5e7b5a4211 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -69,6 +69,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link OAuth2LoginAuthenticationFilter}. @@ -163,8 +164,7 @@ public class OAuth2LoginAuthenticationFilterTests { @Test public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception { String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -176,8 +176,7 @@ public class OAuth2LoginAuthenticationFilterTests { @Test public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = get(requestUri).build(); // NOTE: // A valid Authorization Response contains either a 'code' or 'error' parameter. // Don't set it to force an invalid Authorization Response. @@ -198,10 +197,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -221,10 +219,9 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); // @formatter:off @@ -258,10 +255,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -274,10 +270,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration1, state); @@ -300,10 +295,9 @@ public class OAuth2LoginAuthenticationFilterTests { this.filter.setAuthenticationManager(this.authenticationManager); String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -319,13 +313,9 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("http"); - request.setServerName("localhost"); - request.setServerPort(80); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -350,13 +340,10 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerName("example.com"); - request.setServerPort(443); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get("https://example.com:443" + requestUri) + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -381,13 +368,10 @@ public class OAuth2LoginAuthenticationFilterTests { throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setScheme("https"); - request.setServerName("example.com"); - request.setServerPort(9090); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + MockHttpServletRequest request = get("https://example.com:9090" + requestUri) + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration2, state); @@ -411,10 +395,9 @@ public class OAuth2LoginAuthenticationFilterTests { public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationResult() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); WebAuthenticationDetails webAuthenticationDetails = mock(WebAuthenticationDetails.class); given(this.authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -430,10 +413,9 @@ public class OAuth2LoginAuthenticationFilterTests { this.filter.setAuthenticationResultConverter((authentication) -> null); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthenticationResult(this.registration1); @@ -448,10 +430,9 @@ public class OAuth2LoginAuthenticationFilterTests { authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId())); String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); String state = "state"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, state); + MockHttpServletRequest request = get(requestUri).param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, state) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthenticationResult(this.registration1); diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index b05c1bbd57..64511c5079 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -108,6 +108,7 @@ dependencies { optional 'com.fasterxml.jackson.core:jackson-databind' optional 'org.springframework:spring-jdbc' + testImplementation project(path: ':spring-security-web', configuration: 'tests') testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation "org.assertj:assertj-core" testImplementation "org.skyscreamer:jsonassert" diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java index 57f4221260..1635b4aa53 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml4AuthenticationTokenConverterTests.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -216,15 +217,11 @@ public final class OpenSaml4AuthenticationTokenConverterTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private T signed(T toSign) { diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java index abcf3a4c78..ffb2196836 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -216,15 +217,11 @@ public final class OpenSamlAuthenticationTokenConverterTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private T signed(T toSign) { diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java index d2fdb67c74..4ca03be726 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolverTests.java @@ -28,6 +28,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; @@ -102,9 +103,7 @@ public class OpenSaml4AuthenticationRequestResolverTests { } private MockHttpServletRequest givenRequest(String path) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", path); - request.setServletPath(path); - return request; + return TestMockHttpServletRequests.get(path).build(); } } diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java index c7aeb8b878..cdf9cd7712 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestValidatorParametersResolverTests.java @@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -135,15 +136,11 @@ public final class OpenSaml4LogoutRequestValidatorParametersResolverTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private String serialize(XMLObject object) { diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java index b57db0a895..e036b749cc 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestValidatorParametersResolverTests.java @@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -135,15 +136,11 @@ public final class OpenSamlLogoutRequestValidatorParametersResolverTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private String serialize(XMLObject object) { diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java index 1c35ec58e3..dcf7617c6b 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/OpenSaml5AuthenticationTokenConverterTests.java @@ -42,6 +42,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestOp import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; @@ -216,15 +217,11 @@ public final class OpenSaml5AuthenticationTokenConverterTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private T signed(T toSign) { diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java index 9adf06a6fc..8e4730c561 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml5AuthenticationRequestResolverTests.java @@ -28,6 +28,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; @@ -102,9 +103,7 @@ public class OpenSaml5AuthenticationRequestResolverTests { } private MockHttpServletRequest givenRequest(String path) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", path); - request.setServletPath(path); - return request; + return TestMockHttpServletRequests.get(path).build(); } } diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java index 8ec0cef306..af13bef179 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml5LogoutRequestValidatorParametersResolverTests.java @@ -36,6 +36,7 @@ import org.springframework.security.saml2.provider.service.authentication.logout import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -135,15 +136,11 @@ public final class OpenSaml5LogoutRequestValidatorParametersResolverTests { } private MockHttpServletRequest post(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("POST", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.post(uri).build(); } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private String serialize(XMLObject object) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java index 0684218ffd..1145cca686 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java @@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.registration.InMemory import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -121,9 +122,7 @@ public final class RequestMatcherMetadataResponseResolverTests { } private MockHttpServletRequest get(String uri) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", uri); - request.setServletPath(uri); - return request; + return TestMockHttpServletRequests.get(uri).build(); } private RelyingPartyRegistration withEntityId(String entityId) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java index 32c4b7ed8b..7d95b7765a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java @@ -46,6 +46,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link Saml2LogoutRequestFilter} @@ -76,9 +77,8 @@ public class Saml2LogoutRequestFilterTests { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success()); @@ -105,9 +105,8 @@ public class Saml2LogoutRequestFilterTests { given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication)); this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); given(this.logoutRequestValidator.validate(any())).willReturn(Saml2LogoutValidatorResult.success()); @@ -127,9 +126,7 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout"); - request.setServletPath("/logout"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout").param(Saml2ParameterNames.SAML_RESPONSE, "response").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler); @@ -139,8 +136,7 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); + MockHttpServletRequest request = post("/logout/saml2/slo").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutRequestValidator, this.logoutHandler); @@ -153,9 +149,8 @@ public class Saml2LogoutRequestFilterTests { .build(); Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration) .samlResponse("response") @@ -182,7 +177,6 @@ public class Saml2LogoutRequestFilterTests { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() @@ -210,9 +204,8 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenInvalidBindingErrorLogoutResponseIsPosted() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) @@ -242,9 +235,8 @@ public class Saml2LogoutRequestFilterTests { public void doFilterWhenNoErrorResponseCanBeGeneratedThen401() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_REQUEST, "request") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java index 5973f9589e..77f43dad57 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilterTests.java @@ -43,6 +43,8 @@ import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.mock; import static org.mockito.BDDMockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link Saml2LogoutResponseFilter} @@ -74,9 +76,8 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenSamlResponsePostThenLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration); @@ -94,8 +95,7 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenSamlResponseRedirectThenLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); + MockHttpServletRequest request = get("/logout/saml2/slo").build(); request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() @@ -116,9 +116,7 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenRequestMismatchesThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout"); - request.setServletPath("/logout"); - request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request"); + MockHttpServletRequest request = post("/logout").param(Saml2ParameterNames.SAML_REQUEST, "request").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler); @@ -128,8 +126,7 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenNoSamlRequestOrResponseThenNoLogout() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); + MockHttpServletRequest request = post("/logout/saml2/slo").build(); MockHttpServletResponse response = new MockHttpServletResponse(); this.logoutResponseProcessingFilter.doFilterInternal(request, response, new MockFilterChain()); verifyNoInteractions(this.logoutResponseValidator, this.logoutSuccessHandler); @@ -139,9 +136,8 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenValidatorFailsThenStops() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build(); given(this.relyingPartyRegistrationResolver.resolve(request, "registration-id")).willReturn(registration); @@ -160,9 +156,8 @@ public class Saml2LogoutResponseFilterTests { public void doFilterWhenNoRelyingPartyLogoutThen401() throws Exception { Authentication authentication = new TestingAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(authentication); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo"); - request.setServletPath("/logout/saml2/slo"); - request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "response"); + MockHttpServletRequest request = post("/logout/saml2/slo").param(Saml2ParameterNames.SAML_RESPONSE, "response") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full() .singleLogoutServiceLocation(null) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java index 2b0b837e43..1260d7fd1d 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests.java @@ -39,6 +39,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link Saml2RelyingPartyInitiatedLogoutSuccessHandler} @@ -72,8 +73,7 @@ public class Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests { Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .samlRequest("request") .build(); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/saml2/logout"); - request.setServletPath("/saml2/logout"); + MockHttpServletRequest request = post("/saml2/logout").build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest); this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication); @@ -92,8 +92,7 @@ public class Saml2RelyingPartyInitiatedLogoutSuccessHandlerTests { Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .samlRequest("request") .build(); - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/saml2/logout"); - request.setServletPath("/saml2/logout"); + MockHttpServletRequest request = post("/saml2/logout").build(); MockHttpServletResponse response = new MockHttpServletResponse(); given(this.logoutRequestResolver.resolve(any(), any())).willReturn(logoutRequest); this.logoutRequestSuccessHandler.onLogoutSuccess(request, response, authentication); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index 2e8f7a552a..a1b1a5300f 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -64,6 +64,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Luke Taylor @@ -96,8 +97,7 @@ public class FilterChainProxyTests { }).given(this.filter).doFilter(any(), any(), any()); this.fcp = new FilterChainProxy(new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter))); this.fcp.setFilterChainValidator(mock(FilterChainProxy.FilterChainValidator.class)); - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath("/path"); + this.request = get("/path").build(); this.response = new MockHttpServletResponse(); this.chain = mock(FilterChain.class); } diff --git a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java index 5b6fc258ea..58a6fd2f98 100644 --- a/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterInvocationTests.java @@ -34,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link FilterInvocation}. @@ -45,14 +46,8 @@ public class FilterInvocationTests { @Test public void testGettersAndStringMethods() { - MockHttpServletRequest request = new MockHttpServletRequest(null, null); - request.setServletPath("/HelloWorld"); - request.setPathInfo("/some/more/segments.html"); - request.setServerName("localhost"); - request.setScheme("http"); - request.setServerPort(80); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/HelloWorld/some/more/segments.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", "/some/more/segments.html") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); FilterInvocation fi = new FilterInvocation(request, response, chain); @@ -62,7 +57,7 @@ public class FilterInvocationTests { assertThat(fi.getHttpResponse()).isEqualTo(response); assertThat(fi.getChain()).isEqualTo(chain); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld/some/more/segments.html"); - assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld/some/more/segments.html]"); + assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld/some/more/segments.html]"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld/some/more/segments.html"); } @@ -89,34 +84,23 @@ public class FilterInvocationTests { @Test public void testStringMethodsWithAQueryString() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("foo=bar"); - request.setServletPath("/HelloWorld"); - request.setServerName("localhost"); - request.setScheme("http"); - request.setServerPort(80); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/HelloWorld"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", null) + .queryString("foo=bar") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld?foo=bar"); - assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld?foo=bar]"); + assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld?foo=bar]"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld?foo=bar"); } @Test public void testStringMethodsWithoutAnyQueryString() { - MockHttpServletRequest request = new MockHttpServletRequest(null, null); - request.setServletPath("/HelloWorld"); - request.setServerName("localhost"); - request.setScheme("http"); - request.setServerPort(80); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/HelloWorld"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/HelloWorld", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); assertThat(fi.getRequestUrl()).isEqualTo("/HelloWorld"); - assertThat(fi.toString()).isEqualTo("filter invocation [/HelloWorld]"); + assertThat(fi.toString()).isEqualTo("filter invocation [GET /HelloWorld]"); assertThat(fi.getFullRequestUrl()).isEqualTo("http://localhost/mycontext/HelloWorld"); } diff --git a/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java b/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java index 29c8d46abf..4ecb023d95 100644 --- a/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/RequestMatcherRedirectFilterTests.java @@ -29,6 +29,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link RequestMatcherRedirectFilter}. @@ -42,9 +43,7 @@ public class RequestMatcherRedirectFilterTests { RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(new AntPathRequestMatcher("/context"), "/test"); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/context"); - + MockHttpServletRequest request = get("/context").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -61,8 +60,7 @@ public class RequestMatcherRedirectFilterTests { RequestMatcherRedirectFilter filter = new RequestMatcherRedirectFilter(new AntPathRequestMatcher("/context"), "/test"); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/test"); + MockHttpServletRequest request = get("/test").build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); diff --git a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java index 89159fd737..11fe3e41d2 100644 --- a/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java @@ -58,6 +58,7 @@ import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link ExceptionTranslationFilter}. @@ -86,13 +87,7 @@ public class ExceptionTranslationFilterTests { @Test public void testAccessDeniedWhenAnonymous() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(80); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is @@ -129,13 +124,7 @@ public class ExceptionTranslationFilterTests { @Test public void testAccessDeniedWithRememberMe() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(80); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is remembered @@ -155,8 +144,7 @@ public class ExceptionTranslationFilterTests { @Test public void testAccessDeniedWhenNonAnonymous() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is @@ -178,8 +166,7 @@ public class ExceptionTranslationFilterTests { @Test public void testLocalizedErrorMessages() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); // Setup the FilterChain to thrown an access denied exception FilterChain fc = mockFilterChainWithException(new AccessDeniedException("")); // Setup SecurityContextHolder, as filter needs to check if user is @@ -202,13 +189,7 @@ public class ExceptionTranslationFilterTests { @Test public void redirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthenticationException() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(80); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get().requestUri("/mycontext", "/secure/page.html", null).build(); // Setup the FilterChain to thrown an authentication failure exception FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); // Test @@ -225,13 +206,9 @@ public class ExceptionTranslationFilterTests { public void redirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); - request.setServerPort(8080); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/mycontext"); - request.setRequestURI("/mycontext/secure/page.html"); + MockHttpServletRequest request = get("http://localhost:8080") + .requestUri("/mycontext", "/secure/page.html", null) + .build(); // Setup the FilterChain to thrown an authentication failure exception FilterChain fc = mockFilterChainWithException(new BadCredentialsException("")); // Test @@ -258,8 +235,7 @@ public class ExceptionTranslationFilterTests { @Test public void successfulAccessGrant() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); // Test ExceptionTranslationFilter filter = new ExceptionTranslationFilter(this.mockEntryPoint); assertThat(filter.getAuthenticationEntryPoint()).isSameAs(this.mockEntryPoint); diff --git a/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java index 348b3b2152..ad3f3afa66 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/ChannelProcessingFilterTests.java @@ -32,6 +32,7 @@ import org.springframework.security.web.access.intercept.FilterInvocationSecurit import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link ChannelProcessingFilter}. @@ -81,9 +82,8 @@ public class ChannelProcessingFilterTests { filter.setChannelDecisionManager(new MockChannelDecisionManager(true, "SOME_ATTRIBUTE")); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE"); filter.setSecurityMetadataSource(fids); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/path").build(); request.setQueryString("info=now"); - request.setServletPath("/path"); MockHttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, mock(FilterChain.class)); } @@ -94,9 +94,8 @@ public class ChannelProcessingFilterTests { filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "SOME_ATTRIBUTE")); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "SOME_ATTRIBUTE"); filter.setSecurityMetadataSource(fids); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/path").build(); request.setQueryString("info=now"); - request.setServletPath("/path"); MockHttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, mock(FilterChain.class)); } @@ -107,9 +106,8 @@ public class ChannelProcessingFilterTests { filter.setChannelDecisionManager(new MockChannelDecisionManager(false, "NOT_USED")); MockFilterInvocationDefinitionMap fids = new MockFilterInvocationDefinitionMap("/path", true, "NOT_USED"); filter.setSecurityMetadataSource(fids); - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/PATH_NOT_MATCHING_CONFIG_ATTRIBUTE").build(); request.setQueryString("info=now"); - request.setServletPath("/PATH_NOT_MATCHING_CONFIG_ATTRIBUTE"); MockHttpServletResponse response = new MockHttpServletResponse(); filter.doFilter(request, response, mock(FilterChain.class)); } diff --git a/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java b/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java index 1a3f8f1480..1c290cac06 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/InsecureChannelProcessorTests.java @@ -27,6 +27,7 @@ import org.springframework.security.web.FilterInvocation; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link InsecureChannelProcessor}. @@ -37,13 +38,9 @@ public class InsecureChannelProcessorTests { @Test public void testDecideDetectsAcceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("http"); - request.setServerPort(8080); + MockHttpServletRequest request = get("http://localhost:8080").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); InsecureChannelProcessor processor = new InsecureChannelProcessor(); @@ -53,14 +50,9 @@ public class InsecureChannelProcessorTests { @Test public void testDecideDetectsUnacceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("https"); - request.setSecure(true); - request.setServerPort(8443); + MockHttpServletRequest request = get("https://localhost:8443").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); InsecureChannelProcessor processor = new InsecureChannelProcessor(); diff --git a/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java b/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java index 005736b336..4263cec233 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/SecureChannelProcessorTests.java @@ -27,6 +27,7 @@ import org.springframework.security.web.FilterInvocation; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link SecureChannelProcessor}. @@ -37,14 +38,9 @@ public class SecureChannelProcessorTests { @Test public void testDecideDetectsAcceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("https"); - request.setSecure(true); - request.setServerPort(8443); + MockHttpServletRequest request = get("https://localhost:8443").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); SecureChannelProcessor processor = new SecureChannelProcessor(); @@ -54,13 +50,9 @@ public class SecureChannelProcessorTests { @Test public void testDecideDetectsUnacceptableChannel() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setQueryString("info=true"); - request.setServerName("localhost"); - request.setContextPath("/bigapp"); - request.setServletPath("/servlet"); - request.setScheme("http"); - request.setServerPort(8080); + MockHttpServletRequest request = get("http://localhost:8080").requestUri("/bigapp", "/servlet", null) + .queryString("info=true") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterInvocation fi = new FilterInvocation(request, response, mock(FilterChain.class)); SecureChannelProcessor processor = new SecureChannelProcessor(); diff --git a/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java b/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java index d9bb23e6a2..aec2145136 100644 --- a/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/expression/AbstractVariableEvaluationContextPostProcessorTests.java @@ -31,6 +31,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.FilterInvocation; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -54,8 +55,7 @@ public class AbstractVariableEvaluationContextPostProcessorTests { @BeforeEach public void setup() { this.processor = new VariableEvaluationContextPostProcessor(); - this.request = new MockHttpServletRequest(); - this.request.setServletPath("/"); + this.request = get("/").build(); this.response = new MockHttpServletResponse(); this.invocation = new FilterInvocation(this.request, this.response, new MockFilterChain()); this.context = new StandardEvaluationContext(); diff --git a/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java b/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java index 8a111bd357..0c6a005587 100644 --- a/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java +++ b/web/src/test/java/org/springframework/security/web/access/intercept/DefaultFilterInvocationSecurityMetadataSourceTests.java @@ -32,6 +32,7 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.request; /** * Tests {@link DefaultFilterInvocationSecurityMetadataSource}. @@ -53,7 +54,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests { @Test public void lookupNotRequiringExactMatchSucceedsIfNotMatching() { createFids("/secure/super/**", null); - FilterInvocation fi = createFilterInvocation("/secure/super/somefile.html", null, null, null); + FilterInvocation fi = createFilterInvocation("/secure/super/somefile.html", null, null, "GET"); assertThat(this.fids.getAttributes(fi)).isEqualTo(this.def); } @@ -64,7 +65,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests { @Test public void lookupNotRequiringExactMatchSucceedsIfSecureUrlPathContainsUpperCase() { createFids("/secure/super/**", null); - FilterInvocation fi = createFilterInvocation("/secure", "/super/somefile.html", null, null); + FilterInvocation fi = createFilterInvocation("/secure", "/super/somefile.html", null, "GET"); Collection response = this.fids.getAttributes(fi); assertThat(response).isEqualTo(this.def); } @@ -72,7 +73,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests { @Test public void lookupRequiringExactMatchIsSuccessful() { createFids("/SeCurE/super/**", null); - FilterInvocation fi = createFilterInvocation("/SeCurE/super/somefile.html", null, null, null); + FilterInvocation fi = createFilterInvocation("/SeCurE/super/somefile.html", null, null, "GET"); Collection response = this.fids.getAttributes(fi); assertThat(response).isEqualTo(this.def); } @@ -80,7 +81,7 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests { @Test public void lookupRequiringExactMatchWithAdditionalSlashesIsSuccessful() { createFids("/someAdminPage.html**", null); - FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, "a=/test", null); + FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, "a=/test", "GET"); Collection response = this.fids.getAttributes(fi); assertThat(response); // see SEC-161 (it should truncate after ? // sign).isEqualTo(def) @@ -129,22 +130,19 @@ public class DefaultFilterInvocationSecurityMetadataSourceTests { @Test public void extraQuestionMarkStillMatches() { createFids("/someAdminPage.html*", null); - FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, null, null); + FilterInvocation fi = createFilterInvocation("/someAdminPage.html", null, null, "GET"); Collection response = this.fids.getAttributes(fi); assertThat(response).isEqualTo(this.def); - fi = createFilterInvocation("/someAdminPage.html", null, "?", null); + fi = createFilterInvocation("/someAdminPage.html", null, "?", "GET"); response = this.fids.getAttributes(fi); assertThat(response).isEqualTo(this.def); } private FilterInvocation createFilterInvocation(String servletPath, String pathInfo, String queryString, String method) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI(null); - request.setMethod(method); - request.setServletPath(servletPath); - request.setPathInfo(pathInfo); - request.setQueryString(queryString); + MockHttpServletRequest request = request(method).requestUri(null, servletPath, pathInfo) + .queryString(queryString) + .build(); return new FilterInvocation(request, new MockHttpServletResponse(), mock(FilterChain.class)); } diff --git a/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java b/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java index 4408b14147..77370b0cd7 100644 --- a/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java +++ b/web/src/test/java/org/springframework/security/web/access/intercept/FilterSecurityInterceptorTests.java @@ -53,6 +53,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link FilterSecurityInterceptor}. @@ -188,8 +189,7 @@ public class FilterSecurityInterceptorTests { private FilterInvocation createinvocation() { MockHttpServletResponse response = new MockHttpServletResponse(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/secure/page.html"); + MockHttpServletRequest request = get("/secure/page.html").build(); FilterChain chain = mock(FilterChain.class); FilterInvocation fi = new FilterInvocation(request, response, chain); return fi; diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java index 00e4de0614..752838d2da 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java @@ -59,6 +59,9 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.Builder; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests {@link AbstractAuthenticationProcessingFilter}. @@ -75,13 +78,11 @@ public class AbstractAuthenticationProcessingFilterTests { SimpleUrlAuthenticationFailureHandler failureHandler; private MockHttpServletRequest createMockAuthenticationRequest() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/j_mock_post"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setRequestURI("/mycontext/j_mock_post"); - request.setContextPath("/mycontext"); - return request; + return withMockAuthenticationRequest().build(); + } + + private Builder withMockAuthenticationRequest() { + return get("www.example.com").requestUri("/mycontext", "/j_mock_post", null); } @BeforeEach @@ -100,12 +101,11 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testDefaultProcessesFilterUrlMatchesWithPathParameter() { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login;jsessionid=I8MIONOSTHOR"); + MockHttpServletRequest request = post("/login;jsessionid=I8MIONOSTHOR").build(); MockHttpServletResponse response = new MockHttpServletResponse(); MockAuthenticationFilter filter = new MockAuthenticationFilter(); filter.setFilterProcessesUrl("/login"); DefaultHttpFirewall firewall = new DefaultHttpFirewall(); - request.setServletPath("/login;jsessionid=I8MIONOSTHOR"); // the firewall ensures that path parameters are ignored HttpServletRequest firewallRequest = firewall.getFirewalledRequest(request); assertThat(filter.requiresAuthentication(firewallRequest, response)).isTrue(); @@ -114,9 +114,9 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testFilterProcessesUrlVariationsRespected() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = createMockAuthenticationRequest(); - request.setServletPath("/j_OTHER_LOCATION"); - request.setRequestURI("/mycontext/j_OTHER_LOCATION"); + MockHttpServletRequest request = withMockAuthenticationRequest() + .requestUri("/mycontext", "/j_OTHER_LOCATION", null) + .build(); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); // Setup our expectation that the filter chain will not be invoked, as we redirect @@ -150,9 +150,9 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testIgnoresAnyServletPathOtherThanFilterProcessesUrl() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = createMockAuthenticationRequest(); - request.setServletPath("/some.file.html"); - request.setRequestURI("/mycontext/some.file.html"); + MockHttpServletRequest request = withMockAuthenticationRequest() + .requestUri("/mycontext", "/some.file.html", null) + .build(); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); // Setup our expectation that the filter chain will be invoked, as our request is @@ -227,9 +227,9 @@ public class AbstractAuthenticationProcessingFilterTests { @Test public void testNormalOperationWithRequestMatcherAndAuthenticationManager() throws Exception { // Setup our HTTP request - MockHttpServletRequest request = createMockAuthenticationRequest(); - request.setServletPath("/j_eradicate_corona_virus"); - request.setRequestURI("/mycontext/j_eradicate_corona_virus"); + MockHttpServletRequest request = withMockAuthenticationRequest() + .requestUri("/mycontext", "/j_eradicate_corona_virus", null) + .build(); HttpSession sessionPreAuth = request.getSession(); // Setup our filter configuration MockFilterConfig config = new MockFilterConfig(null, null); diff --git a/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java b/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java index 91e2d93cdf..a1483034c0 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPointTests.java @@ -28,6 +28,7 @@ import org.springframework.security.web.PortMapperImpl; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link LoginUrlAuthenticationEntryPoint}. @@ -73,12 +74,7 @@ public class LoginUrlAuthenticationEntryPointTests { @Test public void testHttpsOperationFromOriginalHttpUrl() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get("http://127.0.0.1").requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); ep.setPortMapper(new PortMapperImpl()); @@ -87,17 +83,17 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortResolver(new MockPortResolver(80, 443)); ep.afterPropertiesSet(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1/bigWebApp/hello"); request.setServerPort(8080); response = new MockHttpServletResponse(); ep.setPortResolver(new MockPortResolver(8080, 8443)); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:8443/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:8443/bigWebApp/hello"); // Now test an unusual custom HTTP:HTTPS is handled properly request.setServerPort(8888); response = new MockHttpServletResponse(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:8443/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:8443/bigWebApp/hello"); PortMapperImpl portMapper = new PortMapperImpl(); Map map = new HashMap<>(); map.put("8888", "9999"); @@ -110,17 +106,13 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortResolver(new MockPortResolver(8888, 9999)); ep.afterPropertiesSet(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com:9999/bigWebApp/hello"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1:9999/bigWebApp/hello"); } @Test public void testHttpsOperationFromOriginalHttpsUrl() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setScheme("https"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(443); + MockHttpServletRequest request = get("https://www.example.com:443").requestUri("/bigWebApp", "/some_path", null) + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); ep.setPortMapper(new PortMapperImpl()); @@ -149,13 +141,7 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortMapper(new PortMapperImpl()); ep.setPortResolver(new MockPortResolver(80, 443)); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get().requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/bigWebApp/hello"); @@ -167,13 +153,8 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setPortResolver(new MockPortResolver(8888, 1234)); ep.setForceHttps(true); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("localhost"); - request.setContextPath("/bigWebApp"); - request.setServerPort(8888); // NB: Port we can't resolve + MockHttpServletRequest request = get("http://localhost:8888").requestUri("/bigWebApp", "/some_path", null) + .build(); // NB: Port we can't resolve MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); // Response doesn't switch to HTTPS, as we didn't know HTTP port 8888 to HTTP port @@ -186,14 +167,7 @@ public class LoginUrlAuthenticationEntryPointTests { LoginUrlAuthenticationEntryPoint ep = new LoginUrlAuthenticationEntryPoint("/hello"); ep.setUseForward(true); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/bigWebApp/some_path"); - request.setServletPath("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get().requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); assertThat(response.getForwardedUrl()).isEqualTo("/hello"); @@ -205,17 +179,10 @@ public class LoginUrlAuthenticationEntryPointTests { ep.setUseForward(true); ep.setForceHttps(true); ep.afterPropertiesSet(); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("/bigWebApp/some_path"); - request.setServletPath("/some_path"); - request.setContextPath("/bigWebApp"); - request.setScheme("http"); - request.setServerName("www.example.com"); - request.setContextPath("/bigWebApp"); - request.setServerPort(80); + MockHttpServletRequest request = get("http://127.0.0.1").requestUri("/bigWebApp", "/some_path", null).build(); MockHttpServletResponse response = new MockHttpServletResponse(); ep.commence(request, response, null); - assertThat(response.getRedirectedUrl()).isEqualTo("https://www.example.com/bigWebApp/some_path"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://127.0.0.1/bigWebApp/some_path"); } // SEC-1498 diff --git a/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java b/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java index de9e3e2561..385ff38d50 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/RequestMatcherDelegatingAuthenticationManagerResolverTests.java @@ -28,6 +28,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.Mockito.mock; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link RequestMatcherDelegatingAuthenticationManagerResolverTests} @@ -48,8 +49,7 @@ public class RequestMatcherDelegatingAuthenticationManagerResolverTests { .add(new AntPathRequestMatcher("/two/**"), this.two) .build(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/one/location"); - request.setServletPath("/one/location"); + MockHttpServletRequest request = get("/one/location").build(); assertThat(resolver.resolve(request)).isEqualTo(this.one); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java index 04f372d71d..d43144f47a 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/UsernamePasswordAuthenticationFilterTests.java @@ -39,6 +39,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests {@link UsernamePasswordAuthenticationFilter}. @@ -128,10 +129,10 @@ public class UsernamePasswordAuthenticationFilterTests { @Test public void testSecurityContextHolderStrategyUsed() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/login"); - request.setServletPath("/login"); - request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod"); - request.addParameter(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala"); + MockHttpServletRequest request = post("/login") + .param(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_USERNAME_KEY, "rod") + .param(UsernamePasswordAuthenticationFilter.SPRING_SECURITY_FORM_PASSWORD_KEY, "koala") + .build(); UsernamePasswordAuthenticationFilter filter = new UsernamePasswordAuthenticationFilter(); filter.setAuthenticationManager(createAuthenticationManager()); SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); diff --git a/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java index 6039fd27a8..e489b5914e 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/logout/LogoutHandlerTests.java @@ -24,6 +24,8 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.firewall.DefaultHttpFirewall; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * @author Luke Taylor @@ -39,22 +41,20 @@ public class LogoutHandlerTests { @Test public void testRequiresLogoutUrlWorksWithPathParams() { - MockHttpServletRequest request = new MockHttpServletRequest("POST", "/context/logout;someparam=blah"); + MockHttpServletRequest request = post().requestUri("/context", "/logout;someparam=blah", null) + .queryString("otherparam=blah") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setContextPath("/context"); - request.setServletPath("/logout;someparam=blah"); - request.setQueryString("otherparam=blah"); DefaultHttpFirewall fw = new DefaultHttpFirewall(); assertThat(this.filter.requiresLogout(fw.getFirewalledRequest(request), response)).isTrue(); } @Test public void testRequiresLogoutUrlWorksWithQueryParams() { - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/context/logout"); - request.setContextPath("/context"); + MockHttpServletRequest request = get().requestUri("/context", "/logout", null) + .queryString("otherparam=blah") + .build(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/logout"); - request.setQueryString("param=blah"); assertThat(this.filter.requiresLogout(request, response)).isTrue(); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java index 4c7085ab70..0f32977543 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ott/GenerateOneTimeTokenFilterTests.java @@ -38,6 +38,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.post; /** * Tests for {@link GenerateOneTimeTokenWebFilter} @@ -55,7 +56,7 @@ public class GenerateOneTimeTokenFilterTests { private static final String USERNAME = "user"; - private final MockHttpServletRequest request = new MockHttpServletRequest(); + private MockHttpServletRequest request; private final MockHttpServletResponse response = new MockHttpServletResponse(); @@ -63,9 +64,7 @@ public class GenerateOneTimeTokenFilterTests { @BeforeEach void setup() { - this.request.setMethod("POST"); - this.request.setServletPath("/ott/generate"); - this.request.setRequestURI("/ott/generate"); + this.request = post("/ott/generate").build(); } @Test @@ -87,6 +86,7 @@ public class GenerateOneTimeTokenFilterTests { void filterWhenUsernameFormParamIsEmptyThenNull() throws ServletException, IOException { given(this.oneTimeTokenService.generate(ArgumentMatchers.any(GenerateOneTimeTokenRequest.class))) .willReturn((new DefaultOneTimeToken(TOKEN, USERNAME, Instant.now()))); + GenerateOneTimeTokenFilter filter = new GenerateOneTimeTokenFilter(this.oneTimeTokenService, this.successHandler); diff --git a/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java index ad38fa6f7c..18f9af5b6f 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ui/DefaultOneTimeTokenSubmitPageGeneratingFilterTests.java @@ -27,6 +27,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests for {@link DefaultOneTimeTokenSubmitPageGeneratingFilter} @@ -37,7 +38,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { DefaultOneTimeTokenSubmitPageGeneratingFilter filter = new DefaultOneTimeTokenSubmitPageGeneratingFilter(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/login/ott"); + MockHttpServletRequest request; MockHttpServletResponse response = new MockHttpServletResponse(); @@ -45,9 +46,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { @BeforeEach void setup() { - this.request.setMethod("GET"); - this.request.setServletPath("/login/ott"); - this.request.setRequestURI("/login/ott"); + this.request = get("/login/ott").build(); } @Test @@ -80,10 +79,9 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { @Test void setContextThenGenerates() throws Exception { - this.request.setContextPath("/context"); - this.request.setRequestURI("/context/login/ott"); + MockHttpServletRequest request = get().requestUri("/context", "/login/ott", null).build(); this.filter.setLoginProcessingUrl("/login/another"); - this.filter.doFilterInternal(this.request, this.response, this.filterChain); + this.filter.doFilterInternal(request, this.response, this.filterChain); String response = this.response.getContentAsString(); assertThat(response).contains("
"); } @@ -101,7 +99,7 @@ class DefaultOneTimeTokenSubmitPageGeneratingFilterTests { void filterThenRenders() throws Exception { this.request.setParameter("token", "this<>!@#\""); this.filter.setLoginProcessingUrl("/login/another"); - this.filter.setResolveHiddenInputs((request) -> Map.of("_csrf", "csrf-token-value")); + this.filter.setResolveHiddenInputs((r) -> Map.of("_csrf", "csrf-token-value")); this.filter.doFilterInternal(this.request, this.response, this.filterChain); String response = this.response.getContentAsString(); assertThat(response).isEqualTo( diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java index faf9f17db4..d2e9c6806e 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java @@ -61,6 +61,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link BasicAuthenticationFilter}. @@ -94,8 +95,7 @@ public class BasicAuthenticationFilterTests { @Test public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath("/some_file.html"); + MockHttpServletRequest request = get("/some_file.html").build(); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, response, chain); @@ -113,9 +113,8 @@ public class BasicAuthenticationFilterTests { @Test public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception { String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -127,9 +126,8 @@ public class BasicAuthenticationFilterTests { @Test public void invalidBase64IsIgnored() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic NOT_VALID_BASE64"); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -143,9 +141,8 @@ public class BasicAuthenticationFilterTests { @Test public void testNormalOperation() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); FilterChain chain = mock(FilterChain.class); @@ -172,9 +169,8 @@ public class BasicAuthenticationFilterTests { @Test public void doFilterWhenSchemeLowercaseThenCaseInsensitveMatchWorks() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); FilterChain chain = mock(FilterChain.class); @@ -187,9 +183,8 @@ public class BasicAuthenticationFilterTests { @Test public void doFilterWhenSchemeMixedCaseThenCaseInsensitiveMatchWorks() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "BaSiC " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, new MockHttpServletResponse(), chain); @@ -200,9 +195,8 @@ public class BasicAuthenticationFilterTests { @Test public void testOtherAuthorizationSchemeIsIgnored() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME"); - request.setServletPath("/some_file.html"); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, new MockHttpServletResponse(), chain); verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); @@ -222,9 +216,8 @@ public class BasicAuthenticationFilterTests { @Test public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception { String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); final MockHttpServletResponse response1 = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); this.filter.doFilter(request, response1, chain); @@ -240,7 +233,6 @@ public class BasicAuthenticationFilterTests { chain = mock(FilterChain.class); this.filter.doFilter(request, response2, chain); verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); - request.setServletPath("/some_file.html"); // Test - the filter chain will not be invoked, as we get a 401 forbidden response MockHttpServletResponse response = response2; assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -250,9 +242,8 @@ public class BasicAuthenticationFilterTests { @Test public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception { String token = "rod:WRONG_PASSWORD"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); this.filter = new BasicAuthenticationFilter(this.manager); assertThat(this.filter.isIgnoreFailure()).isTrue(); @@ -266,9 +257,8 @@ public class BasicAuthenticationFilterTests { @Test public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception { String token = "rod:WRONG_PASSWORD"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); assertThat(this.filter.isIgnoreFailure()).isFalse(); final MockHttpServletResponse response = new MockHttpServletResponse(); @@ -284,9 +274,8 @@ public class BasicAuthenticationFilterTests { @Test public void skippedOnErrorDispatch() throws Exception { String token = "bad:credentials"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -307,10 +296,9 @@ public class BasicAuthenticationFilterTests { given(this.manager.authenticate(not(eq(rodRequest)))).willThrow(new BadCredentialsException("")); this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); String token = "rod:äöü"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8))); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -336,10 +324,9 @@ public class BasicAuthenticationFilterTests { this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter.setCredentialsCharset("ISO-8859-1"); String token = "rod:äöü"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.ISO_8859_1))); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -367,10 +354,9 @@ public class BasicAuthenticationFilterTests { this.filter = new BasicAuthenticationFilter(this.manager, new BasicAuthenticationEntryPoint()); this.filter.setCredentialsCharset("ISO-8859-1"); String token = "rod:äöü"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token.getBytes(StandardCharsets.UTF_8))); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -383,9 +369,8 @@ public class BasicAuthenticationFilterTests { @Test public void requestWhenEmptyBasicAuthorizationHeaderTokenThenUnauthorized() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic "); - request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); final MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); @@ -401,9 +386,8 @@ public class BasicAuthenticationFilterTests { SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); this.filter.setSecurityContextRepository(securityContextRepository); String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/some_file.html").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/some_file.html"); MockHttpServletResponse response = new MockHttpServletResponse(); // Test assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); @@ -496,9 +480,8 @@ public class BasicAuthenticationFilterTests { public void doFilterWhenCustomAuthenticationConverterThatIgnoresRequestThenIgnores() throws Exception { this.filter.setAuthenticationConverter(new TestAuthenticationConverter()); String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/ignored").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/ignored"); FilterChain filterChain = mock(FilterChain.class); MockHttpServletResponse response = new MockHttpServletResponse(); this.filter.doFilter(request, response, filterChain); @@ -513,9 +496,8 @@ public class BasicAuthenticationFilterTests { public void doFilterWhenCustomAuthenticationConverterRequestThenAuthenticate() throws Exception { this.filter.setAuthenticationConverter(new TestAuthenticationConverter()); String token = "rod:koala"; - MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletRequest request = get("/ok").build(); request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token)); - request.setServletPath("/ok"); FilterChain filterChain = mock(FilterChain.class); MockHttpServletResponse response = new MockHttpServletResponse(); this.filter.doFilter(request, response, filterChain); diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java index 230c554fcd..c3375abf29 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java @@ -53,6 +53,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * Tests {@link DigestAuthenticationFilter}. @@ -131,8 +132,7 @@ public class DigestAuthenticationFilterTests { this.filter = new DigestAuthenticationFilter(); this.filter.setUserDetailsService(uds); this.filter.setAuthenticationEntryPoint(ep); - this.request = new MockHttpServletRequest("GET", REQUEST_URI); - this.request.setServletPath(REQUEST_URI); + this.request = get(REQUEST_URI).build(); } @Test diff --git a/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java b/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java index 683db01360..1088f993e5 100644 --- a/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/debug/DebugFilterTests.java @@ -41,6 +41,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -120,10 +121,7 @@ public class DebugFilterTests { @Test public void doFilterLogsProperly() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setMethod("GET"); - request.setServletPath("/path"); - request.setPathInfo("/"); + MockHttpServletRequest request = get().requestUri(null, "/path", "/").build(); request.addHeader("A", "A Value"); request.addHeader("A", "Another Value"); request.addHeader("B", "B Value"); diff --git a/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java index c7ce4d72f3..00f656816b 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/DefaultHttpFirewallTests.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Luke Taylor @@ -34,8 +35,7 @@ public class DefaultHttpFirewallTests { public void unnormalizedPathsAreRejected() { DefaultHttpFirewall fw = new DefaultHttpFirewall(); for (String path : this.unnormalizedPaths) { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setServletPath(path); + MockHttpServletRequest request = get().requestUri(path).build(); assertThatExceptionOfType(RequestRejectedException.class) .isThrownBy(() -> fw.getFirewalledRequest(request)); request.setPathInfo(path); diff --git a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java index f5406f0094..e7a46bfe54 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java @@ -27,6 +27,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; /** * @author Rob Winch @@ -112,8 +113,7 @@ public class StrictHttpFirewallTests { @Test public void getFirewalledRequestWhenServletPathNotNormalizedThenThrowsRequestRejectedException() { for (String path : this.unnormalizedPaths) { - this.request = new MockHttpServletRequest("GET", ""); - this.request.setServletPath(path); + this.request = get().requestUri(path).build(); assertThatExceptionOfType(RequestRejectedException.class) .isThrownBy(() -> this.firewall.getFirewalledRequest(this.request)); } diff --git a/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java b/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java index 8263b776f0..9671303750 100644 --- a/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java +++ b/web/src/test/java/org/springframework/security/web/util/matcher/RegexRequestMatcherTests.java @@ -28,6 +28,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.BDDMockito.given; +import static org.springframework.security.web.servlet.TestMockHttpServletRequests.get; import static org.springframework.security.web.util.matcher.RegexRequestMatcher.regexMatcher; /** @@ -50,8 +51,7 @@ public class RegexRequestMatcherTests { @Test public void matchesIfHttpMethodAndPathMatch() { RegexRequestMatcher matcher = new RegexRequestMatcher(".*", "GET"); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/anything"); - request.setServletPath("/anything"); + MockHttpServletRequest request = get("/anything").build(); assertThat(matcher.matches(request)).isTrue(); }