Fix test .standaloneSetup

Previously, Spring Security's test support did not work well with the
standalone setup. This was because the springSecurityFilterChain was not
found by the WebTestUtils.

This commit ensures that the springSecurityFilterChain is added as a
servlet attribute if it is explicitly defined. WebTestUtils can then
find the springSecurityFilterChain in the ServletContext.

Fixes gh-3881
This commit is contained in:
Rob Winch 2016-05-13 12:57:06 -05:00 committed by Joe Grandja
parent 602bb457b8
commit 7b61a44929
5 changed files with 101 additions and 12 deletions

View File

@ -68,6 +68,8 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
}
builder.addFilters(this.springSecurityFilterChain);
context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN,
this.springSecurityFilterChain);
return testSecurityContext();
}

View File

@ -18,9 +18,11 @@ package org.springframework.security.test.web.support;
import java.util.List;
import javax.servlet.Filter;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.security.config.BeanIds;
import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
@ -115,18 +117,9 @@ public abstract class WebTestUtils {
@SuppressWarnings("unchecked")
static <T extends Filter> T findFilter(HttpServletRequest request,
Class<T> filterClass) {
WebApplicationContext webApplicationContext = WebApplicationContextUtils
.getWebApplicationContext(request.getServletContext());
if (webApplicationContext == null) {
return null;
}
Filter springSecurityFilterChain = null;
try {
springSecurityFilterChain = webApplicationContext.getBean(
AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
Filter.class);
}
catch (NoSuchBeanDefinitionException notFound) {
ServletContext servletContext = request.getServletContext();
Filter springSecurityFilterChain = getSpringSecurityFilterChain(servletContext);
if (springSecurityFilterChain == null) {
return null;
}
List<Filter> filters = (List<Filter>) ReflectionTestUtils
@ -142,6 +135,26 @@ public abstract class WebTestUtils {
return null;
}
private static Filter getSpringSecurityFilterChain(ServletContext servletContext) {
Filter result = (Filter) servletContext
.getAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
if (result != null) {
return result;
}
WebApplicationContext webApplicationContext = WebApplicationContextUtils
.getWebApplicationContext(servletContext);
if (webApplicationContext != null) {
try {
return webApplicationContext.getBean(
AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
Filter.class);
}
catch (NoSuchBeanDefinitionException notFound) {
}
}
return null;
}
private WebTestUtils() {
}
}

View File

@ -27,6 +27,7 @@ import javax.servlet.http.HttpSession;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.mock.web.MockHttpServletRequest;
@ -34,6 +35,8 @@ import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
@ -58,6 +61,10 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
public class SecurityMockMvcRequestPostProcessorsCsrfTests {
@Autowired
WebApplicationContext wac;
@Autowired
TheController controller;
@Autowired
FilterChainProxy springSecurityFilterChain;
MockMvc mockMvc;
@ -69,7 +76,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
.apply(springSecurity())
.build();
// @formatter:on
}
// gh-3881
@Test
public void csrfWithStandalone() throws Exception {
// @formatter:off
this.mockMvc = MockMvcBuilders
.standaloneSetup(this.controller)
.apply(springSecurity(this.springSecurityFilterChain))
.build();
this.mockMvc.perform(post("/").with(csrf()))
.andExpect(status().is2xxSuccessful())
.andExpect(csrfAsParam());
// @formatter:on
}
@Test

View File

@ -16,12 +16,15 @@
package org.springframework.security.test.web.servlet.setup;
import javax.servlet.Filter;
import javax.servlet.ServletContext;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.security.config.BeanIds;
import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
import org.springframework.web.context.WebApplicationContext;
@ -40,6 +43,13 @@ public class SecurityMockMvcConfigurerTests {
private ConfigurableMockMvcBuilder<?> builder;
@Mock
private WebApplicationContext context;
@Mock
private ServletContext servletContext;
@Before
public void setup() {
when(this.context.getServletContext()).thenReturn(this.servletContext);
}
@Test
public void beforeMockMvcCreatedOverrideBean() throws Exception {
@ -49,6 +59,8 @@ public class SecurityMockMvcConfigurerTests {
configurer.beforeMockMvcCreated(this.builder, this.context);
verify(this.builder).addFilters(this.filter);
verify(this.servletContext).setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN,
this.filter);
}
@Test

View File

@ -25,14 +25,19 @@ import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Configuration;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.config.BeanIds;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
@ -129,6 +134,34 @@ public class WebTestUtilsTests {
SecurityContextPersistenceFilter.class)).isNull();
}
@Test
public void findFilterNoSpringSecurityFilterChainInContext() {
loadConfig(NoSecurityConfig.class);
CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository());
FilterChainProxy springSecurityFilterChain = new FilterChainProxy(
new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind));
this.request.getServletContext().setAttribute(
BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain);
assertThat(WebTestUtils.findFilter(this.request, toFind.getClass()))
.isEqualTo(toFind);
}
@Test
public void findFilterExplicitWithSecurityFilterInContext() {
loadConfig(SecurityConfigWithDefaults.class);
CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository());
FilterChainProxy springSecurityFilterChain = new FilterChainProxy(
new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind));
this.request.getServletContext().setAttribute(
BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain);
assertThat(WebTestUtils.findFilter(this.request, toFind.getClass()))
.isSameAs(toFind);
}
private void loadConfig(Class<?> config) {
AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
context.register(config);
@ -180,4 +213,13 @@ public class WebTestUtilsTests {
}
// @formatter:on
}
@Configuration
static class NoSecurityConfig {
}
@EnableWebSecurity
static class SecurityConfigWithDefaults extends WebSecurityConfigurerAdapter {
}
}