From fa9898dd6df87f581f7e6ee6d99704c3e0dd8554 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1vid=20Kov=C3=A1cs?= Date: Sat, 9 May 2020 21:53:33 +0200 Subject: [PATCH] formLogin() and login() implement Mergable This is necessary so that default requests like Spring REST Docs work. Closes gh-7572 --- .../SecurityMockMvcRequestBuilders.java | 77 ++++++++++++++++--- .../SecurityMockMvcRequestPostProcessors.java | 2 +- ...yMockMvcRequestBuildersFormLoginTests.java | 38 ++++++++- ...MockMvcRequestBuildersFormLogoutTests.java | 40 +++++++++- 4 files changed, 141 insertions(+), 16 deletions(-) diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java index 451f76fcd5..fb5a69d24d 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java @@ -15,16 +15,18 @@ */ package org.springframework.security.test.web.servlet.request; -import javax.servlet.ServletContext; - +import org.springframework.beans.Mergeable; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.web.util.UriComponentsBuilder; +import javax.servlet.ServletContext; + import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -86,15 +88,23 @@ public final class SecurityMockMvcRequestBuilders { * @author Rob Winch * @since 4.0 */ - public static final class LogoutRequestBuilder implements RequestBuilder { + public static final class LogoutRequestBuilder implements RequestBuilder, Mergeable { private String logoutUrl = "/logout"; private RequestPostProcessor postProcessor = csrf(); + private Mergeable parent; @Override public MockHttpServletRequest buildRequest(ServletContext servletContext) { - MockHttpServletRequest request = post(this.logoutUrl) - .accept(MediaType.TEXT_HTML, MediaType.ALL) - .buildRequest(servletContext); + MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl) + .accept(MediaType.TEXT_HTML, MediaType.ALL); + + if (this.parent != null) { + logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent); + } + + MockHttpServletRequest request = logoutRequest.buildRequest(servletContext); + logoutRequest.postProcessRequest(request); + return this.postProcessor.postProcessRequest(request); } @@ -122,6 +132,24 @@ public final class SecurityMockMvcRequestBuilders { return this; } + @Override + public boolean isMergeEnabled() { + return true; + } + + @Override + public Object merge(Object parent) { + if (parent == null) { + return this; + } + if (parent instanceof Mergeable) { + this.parent = (Mergeable) parent; + return this; + } else { + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); + } + } + private LogoutRequestBuilder() { } } @@ -132,22 +160,31 @@ public final class SecurityMockMvcRequestBuilders { * @author Rob Winch * @since 4.0 */ - public static final class FormLoginRequestBuilder implements RequestBuilder { + public static final class FormLoginRequestBuilder implements RequestBuilder, Mergeable { private String usernameParam = "username"; private String passwordParam = "password"; private String username = "user"; private String password = "password"; private String loginProcessingUrl = "/login"; private MediaType acceptMediaType = MediaType.APPLICATION_FORM_URLENCODED; + private Mergeable parent; private RequestPostProcessor postProcessor = csrf(); @Override public MockHttpServletRequest buildRequest(ServletContext servletContext) { - MockHttpServletRequest request = post(this.loginProcessingUrl) - .accept(this.acceptMediaType).param(this.usernameParam, this.username) - .param(this.passwordParam, this.password) - .buildRequest(servletContext); + MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl) + .accept(this.acceptMediaType) + .param(this.usernameParam, this.username) + .param(this.passwordParam, this.password); + + if (this.parent != null) { + loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent); + } + + MockHttpServletRequest request = loginRequest.buildRequest(servletContext); + loginRequest.postProcessRequest(request); + return this.postProcessor.postProcessRequest(request); } @@ -258,6 +295,24 @@ public final class SecurityMockMvcRequestBuilders { return this; } + @Override + public boolean isMergeEnabled() { + return true; + } + + @Override + public Object merge(Object parent) { + if (parent == null) { + return this; + } + if (parent instanceof Mergeable ) { + this.parent = (Mergeable) parent; + return this; + } else { + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); + } + } + private FormLoginRequestBuilder() { } } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 017567930a..aeb55699b4 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -410,7 +410,7 @@ public final class SecurityMockMvcRequestPostProcessors { private final CsrfTokenRepository delegate; - private TestCsrfTokenRepository(CsrfTokenRepository delegate) { + TestCsrfTokenRepository(CsrfTokenRepository delegate) { this.delegate = delegate; } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 0a2501449a..3a00f6a080 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -17,14 +17,25 @@ package org.springframework.security.test.web.servlet.request; import org.junit.Before; import org.junit.Test; - +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import org.springframework.test.web.servlet.request.RequestPostProcessor; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import java.util.Arrays; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.powermock.api.mockito.PowerMockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; public class SecurityMockMvcRequestBuildersFormLoginTests { @@ -82,6 +93,31 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { assertThat(request.getRequestURI()).isEqualTo("/uri-login/val1/val2"); } + /** + * spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together + * with our request builders. (gh-7572) + * @throws Exception + */ + @Test + public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception { + RequestPostProcessor postProcessor = mock(RequestPostProcessor.class); + when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0)); + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object()) + .defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor)) + .build(); + + + MvcResult mvcResult = mockMvc.perform(formLogin()).andReturn(); + assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name()); + assertThat(mvcResult.getRequest().getHeader("Accept")) + .isEqualTo(MediaType.toString(Arrays.asList(MediaType.APPLICATION_FORM_URLENCODED))); + assertThat(mvcResult.getRequest().getParameter("username")).isEqualTo("user"); + assertThat(mvcResult.getRequest().getParameter("password")).isEqualTo("password"); + assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/login"); + assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty(); + verify(postProcessor).postProcessRequest(any()); + } + // gh-3920 @Test public void usesAcceptMediaForContentNegotiation() { diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index 1e86868d3e..b6271e9409 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -15,15 +15,28 @@ */ package org.springframework.security.test.web.servlet.request; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout; - import org.junit.Before; import org.junit.Test; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import org.springframework.test.web.servlet.request.RequestPostProcessor; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import java.util.Arrays; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.powermock.api.mockito.PowerMockito.when; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout; public class SecurityMockMvcRequestBuildersFormLogoutTests { private MockServletContext servletContext; @@ -71,4 +84,25 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); } + /** + * spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together + * with our request builders. (gh-7572) + * @throws Exception + */ + @Test + public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception { + RequestPostProcessor postProcessor = mock(RequestPostProcessor.class); + when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0)); + MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object()) + .defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor)) + .build(); + + MvcResult mvcResult = mockMvc.perform(logout()).andReturn(); + assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name()); + assertThat(mvcResult.getRequest().getHeader("Accept")) + .isEqualTo(MediaType.toString(Arrays.asList(MediaType.TEXT_HTML, MediaType.ALL))); + assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/logout"); + assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty(); + verify(postProcessor).postProcessRequest(any()); + } }