formLogin() and login() implement Mergable

This is necessary so that default requests like Spring REST Docs work.

Closes gh-7572
This commit is contained in:
Dávid Kovács 2020-05-09 21:53:33 +02:00 committed by Rob Winch
parent bff6d82dd0
commit fa9898dd6d
4 changed files with 141 additions and 16 deletions

View File

@ -15,16 +15,18 @@
*/
package org.springframework.security.test.web.servlet.request;
import javax.servlet.ServletContext;
import org.springframework.beans.Mergeable;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.RequestBuilder;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.ServletContext;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@ -86,15 +88,23 @@ public final class SecurityMockMvcRequestBuilders {
* @author Rob Winch
* @since 4.0
*/
public static final class LogoutRequestBuilder implements RequestBuilder {
public static final class LogoutRequestBuilder implements RequestBuilder, Mergeable {
private String logoutUrl = "/logout";
private RequestPostProcessor postProcessor = csrf();
private Mergeable parent;
@Override
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
MockHttpServletRequest request = post(this.logoutUrl)
.accept(MediaType.TEXT_HTML, MediaType.ALL)
.buildRequest(servletContext);
MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl)
.accept(MediaType.TEXT_HTML, MediaType.ALL);
if (this.parent != null) {
logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent);
}
MockHttpServletRequest request = logoutRequest.buildRequest(servletContext);
logoutRequest.postProcessRequest(request);
return this.postProcessor.postProcessRequest(request);
}
@ -122,6 +132,24 @@ public final class SecurityMockMvcRequestBuilders {
return this;
}
@Override
public boolean isMergeEnabled() {
return true;
}
@Override
public Object merge(Object parent) {
if (parent == null) {
return this;
}
if (parent instanceof Mergeable) {
this.parent = (Mergeable) parent;
return this;
} else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
}
private LogoutRequestBuilder() {
}
}
@ -132,22 +160,31 @@ public final class SecurityMockMvcRequestBuilders {
* @author Rob Winch
* @since 4.0
*/
public static final class FormLoginRequestBuilder implements RequestBuilder {
public static final class FormLoginRequestBuilder implements RequestBuilder, Mergeable {
private String usernameParam = "username";
private String passwordParam = "password";
private String username = "user";
private String password = "password";
private String loginProcessingUrl = "/login";
private MediaType acceptMediaType = MediaType.APPLICATION_FORM_URLENCODED;
private Mergeable parent;
private RequestPostProcessor postProcessor = csrf();
@Override
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
MockHttpServletRequest request = post(this.loginProcessingUrl)
.accept(this.acceptMediaType).param(this.usernameParam, this.username)
.param(this.passwordParam, this.password)
.buildRequest(servletContext);
MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl)
.accept(this.acceptMediaType)
.param(this.usernameParam, this.username)
.param(this.passwordParam, this.password);
if (this.parent != null) {
loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent);
}
MockHttpServletRequest request = loginRequest.buildRequest(servletContext);
loginRequest.postProcessRequest(request);
return this.postProcessor.postProcessRequest(request);
}
@ -258,6 +295,24 @@ public final class SecurityMockMvcRequestBuilders {
return this;
}
@Override
public boolean isMergeEnabled() {
return true;
}
@Override
public Object merge(Object parent) {
if (parent == null) {
return this;
}
if (parent instanceof Mergeable ) {
this.parent = (Mergeable) parent;
return this;
} else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
}
private FormLoginRequestBuilder() {
}
}

View File

@ -410,7 +410,7 @@ public final class SecurityMockMvcRequestPostProcessors {
private final CsrfTokenRepository delegate;
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
TestCsrfTokenRepository(CsrfTokenRepository delegate) {
this.delegate = delegate;
}

View File

@ -17,14 +17,25 @@ package org.springframework.security.test.web.servlet.request;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import java.util.Arrays;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
public class SecurityMockMvcRequestBuildersFormLoginTests {
@ -82,6 +93,31 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
assertThat(request.getRequestURI()).isEqualTo("/uri-login/val1/val2");
}
/**
* spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together
* with our request builders. (gh-7572)
* @throws Exception
*/
@Test
public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception {
RequestPostProcessor postProcessor = mock(RequestPostProcessor.class);
when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0));
MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object())
.defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor))
.build();
MvcResult mvcResult = mockMvc.perform(formLogin()).andReturn();
assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name());
assertThat(mvcResult.getRequest().getHeader("Accept"))
.isEqualTo(MediaType.toString(Arrays.asList(MediaType.APPLICATION_FORM_URLENCODED)));
assertThat(mvcResult.getRequest().getParameter("username")).isEqualTo("user");
assertThat(mvcResult.getRequest().getParameter("password")).isEqualTo("password");
assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/login");
assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty();
verify(postProcessor).postProcessRequest(any());
}
// gh-3920
@Test
public void usesAcceptMediaForContentNegotiation() {

View File

@ -15,15 +15,28 @@
*/
package org.springframework.security.test.web.servlet.request;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import java.util.Arrays;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
public class SecurityMockMvcRequestBuildersFormLogoutTests {
private MockServletContext servletContext;
@ -71,4 +84,25 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");
}
/**
* spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together
* with our request builders. (gh-7572)
* @throws Exception
*/
@Test
public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception {
RequestPostProcessor postProcessor = mock(RequestPostProcessor.class);
when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0));
MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object())
.defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor))
.build();
MvcResult mvcResult = mockMvc.perform(logout()).andReturn();
assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name());
assertThat(mvcResult.getRequest().getHeader("Accept"))
.isEqualTo(MediaType.toString(Arrays.asList(MediaType.TEXT_HTML, MediaType.ALL)));
assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/logout");
assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty();
verify(postProcessor).postProcessRequest(any());
}
}