diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java index d732445b1a..44571562a8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java @@ -25,19 +25,20 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.userdetails.PasswordEncodedUser; -import org.springframework.security.web.FilterChainProxy; import org.springframework.test.web.servlet.MockMvc; import org.springframework.web.filter.OncePerRequestFilter; -import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.ThrowableAssert.catchThrowable; +import static org.mockito.Mockito.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -51,9 +52,6 @@ public class HttpConfigurationTests { @Rule public final SpringTestRule spring = new SpringTestRule(); - @Autowired - private FilterChainProxy springSecurityFilterChain; - @Autowired private MockMvc mockMvc; @@ -93,16 +91,22 @@ public class HttpConfigurationTests { // https://github.com/SpringSource/spring-security-javaconfig/issues/104 @Test public void configureWhenAddFilterCasAuthenticationFilterThenFilterAdded() throws Exception { + CasAuthenticationFilterConfig.CAS_AUTHENTICATION_FILTER = spy(new CasAuthenticationFilter()); this.spring.register(CasAuthenticationFilterConfig.class).autowire(); - assertThat(this.findFilter(CasAuthenticationFilter.class, this.springSecurityFilterChain)).isNotNull(); + this.mockMvc.perform(get("/")); + + verify(CasAuthenticationFilterConfig.CAS_AUTHENTICATION_FILTER).doFilter( + any(ServletRequest.class), any(ServletResponse.class), any(FilterChain.class)); } @EnableWebSecurity static class CasAuthenticationFilterConfig extends WebSecurityConfigurerAdapter { + static CasAuthenticationFilter CAS_AUTHENTICATION_FILTER; + protected void configure(HttpSecurity http) throws Exception { http - .addFilter(new CasAuthenticationFilter()); + .addFilter(CAS_AUTHENTICATION_FILTER); } } @@ -131,22 +135,4 @@ public class HttpConfigurationTests { .httpBasic(); } } - - private T findFilter(Class filterType, FilterChainProxy filterChainProxy) { - return this.findFilter(filterType, filterChainProxy, 0); - } - - private T findFilter(Class filterType, FilterChainProxy filterChainProxy, int filterChainIndex) { - if (filterChainIndex >= filterChainProxy.getFilterChains().size()) { - return null; - } - - Filter filter = filterChainProxy.getFilterChains().get(filterChainIndex).getFilters() - .stream() - .filter(f -> f.getClass().isAssignableFrom(filterType)) - .findFirst() - .orElse(null); - - return (T) filter; - } }