Fix HttpServlet3RequestFactory Logout Handlers

Previously there was a problem with Servlet API logout integration
when Servlet API was configured before log out.

This ensures that logout handlers is a reference to the logout handlers
vs copying the logout handlers. This ensures that the ordering does not
matter.

Closes gh-4760
This commit is contained in:
Rob Winch 2020-03-30 16:18:02 -05:00
parent b055f8bb25
commit 91728ef53b
2 changed files with 46 additions and 6 deletions

View File

@ -22,6 +22,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolver;
@ -44,10 +45,14 @@ import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.security.web.authentication.logout.LogoutSuccessEventPublishingLogoutHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessEventPublishingLogoutHandler;
import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter;
import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.ConfigurableWebApplicationContext;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -60,6 +65,7 @@ import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@ -329,6 +335,39 @@ public class ServletApiConfigurerTests {
} }
} }
@Test
public void logoutServletApiWhenCsrfDisabled() throws Exception {
ConfigurableWebApplicationContext context = this.spring.register(CsrfDisabledConfig.class).getContext();
MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(context)
.apply(springSecurity())
.build();
MvcResult mvcResult = mockMvc.perform(get("/"))
.andReturn();
assertThat(mvcResult.getRequest().getSession(false)).isNull();
}
@Configuration
@EnableWebSecurity
static class CsrfDisabledConfig extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.csrf().disable();
// @formatter:on
}
@RestController
static class LogoutController {
@GetMapping("/")
String logout(HttpServletRequest request) throws ServletException {
request.getSession().setAttribute("foo", "bar");
request.logout();
return "logout";
}
}
}
private <T extends Filter> T getFilter(Class<T> filterClass) { private <T extends Filter> T getFilter(Class<T> filterClass) {
return (T) getFilters().stream() return (T) getFilters().stream()
.filter(filterClass::isInstance) .filter(filterClass::isInstance)

View File

@ -42,7 +42,6 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.logout.CompositeLogoutHandler;
import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@ -80,7 +79,7 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
private AuthenticationEntryPoint authenticationEntryPoint; private AuthenticationEntryPoint authenticationEntryPoint;
private AuthenticationManager authenticationManager; private AuthenticationManager authenticationManager;
private LogoutHandler logoutHandler; private List<LogoutHandler> logoutHandlers;
HttpServlet3RequestFactory(String rolePrefix) { HttpServlet3RequestFactory(String rolePrefix) {
this.rolePrefix = rolePrefix; this.rolePrefix = rolePrefix;
@ -144,7 +143,7 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
* {@link HttpServletRequest#logout()}. * {@link HttpServletRequest#logout()}.
*/ */
public void setLogoutHandlers(List<LogoutHandler> logoutHandlers) { public void setLogoutHandlers(List<LogoutHandler> logoutHandlers) {
this.logoutHandler = CollectionUtils.isEmpty(logoutHandlers) ? null : new CompositeLogoutHandler(logoutHandlers); this.logoutHandlers = logoutHandlers;
} }
/** /**
@ -244,8 +243,8 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
@Override @Override
public void logout() throws ServletException { public void logout() throws ServletException {
LogoutHandler handler = HttpServlet3RequestFactory.this.logoutHandler; List<LogoutHandler> handlers = HttpServlet3RequestFactory.this.logoutHandlers;
if (handler == null) { if (CollectionUtils.isEmpty(handlers)) {
HttpServlet3RequestFactory.this.logger.debug( HttpServlet3RequestFactory.this.logger.debug(
"logoutHandlers is null, so allowing original HttpServletRequest to handle logout"); "logoutHandlers is null, so allowing original HttpServletRequest to handle logout");
super.logout(); super.logout();
@ -253,8 +252,10 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
} }
Authentication authentication = SecurityContextHolder.getContext() Authentication authentication = SecurityContextHolder.getContext()
.getAuthentication(); .getAuthentication();
for (LogoutHandler handler : handlers) {
handler.logout(this, this.response, authentication); handler.logout(this, this.response, authentication);
} }
}
private boolean isAuthenticated() { private boolean isAuthenticated() {
Principal userPrincipal = getUserPrincipal(); Principal userPrincipal = getUserPrincipal();