From c6b8fe5e55d5b4e7cb1127d20365eaf01b665143 Mon Sep 17 00:00:00 2001 From: Luke Taylor Date: Sun, 3 Jan 2010 19:06:04 +0000 Subject: [PATCH] SEC-1346: Added missing 'return' statements after redirects. ConcurrentSessionFilter and SessionManagementFilter now return immediately after redirecting to the expired URL and invalid session URLs respectively. Extra tests added to check. --- .../web/session/ConcurrentSessionFilter.java | 3 +- .../web/session/SessionManagementFilter.java | 3 +- .../security/web/util/UrlUtils.java | 2 - .../ConcurrentSessionFilterTests.java | 113 +++++++----------- .../session/SessionManagementFilterTests.java | 35 +++++- 5 files changed, 79 insertions(+), 77 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java b/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java index 295ace08ec..e907ff208c 100644 --- a/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java +++ b/web/src/main/java/org/springframework/security/web/session/ConcurrentSessionFilter.java @@ -53,7 +53,6 @@ import org.springframework.web.filter.GenericFilterBean; * {@link org.springframework.security.web.session.HttpSessionEventPublisher} registered in web.xml.

* * @author Ben Alex - * @version $Id$ */ public class ConcurrentSessionFilter extends GenericFilterBean { //~ Instance fields ================================================================================================ @@ -91,6 +90,8 @@ public class ConcurrentSessionFilter extends GenericFilterBean { if (targetUrl != null) { redirectStrategy.sendRedirect(request, response, targetUrl); + + return; } else { response.getWriter().print("This session has been expired (possibly due to multiple concurrent " + "logins being attempted as the same user)."); diff --git a/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java b/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java index 4577f7a4e1..155cd75284 100644 --- a/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java +++ b/web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java @@ -31,7 +31,6 @@ import org.springframework.web.filter.GenericFilterBean; * * @author Martin Algesten * @author Luke Taylor - * @version $Id$ * @since 2.0 */ public class SessionManagementFilter extends GenericFilterBean { @@ -87,6 +86,8 @@ public class SessionManagementFilter extends GenericFilterBean { if (invalidSessionUrl != null) { logger.debug("Redirecting to '" + invalidSessionUrl + "'"); redirectStrategy.sendRedirect(request, response, invalidSessionUrl); + + return; } } } diff --git a/web/src/main/java/org/springframework/security/web/util/UrlUtils.java b/web/src/main/java/org/springframework/security/web/util/UrlUtils.java index e2c48ef851..ff2a4f27f5 100644 --- a/web/src/main/java/org/springframework/security/web/util/UrlUtils.java +++ b/web/src/main/java/org/springframework/security/web/util/UrlUtils.java @@ -23,7 +23,6 @@ import javax.servlet.http.HttpServletRequest; * URL formatting conventions will affect all users.

* * @author Ben Alex - * @version $Id$ */ public final class UrlUtils { //~ Methods ======================================================================================================== @@ -94,7 +93,6 @@ public final class UrlUtils { /** * Obtains the web application-specific fragment of the URL. - */ private static String buildRequestUrl(String servletPath, String requestURI, String contextPath, String pathInfo, String queryString) { diff --git a/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java b/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java index 3174fb9498..1f85a8115a 100644 --- a/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/concurrent/ConcurrentSessionFilterTests.java @@ -15,79 +15,70 @@ package org.springframework.security.web.concurrent; -import junit.framework.TestCase; -import org.springframework.mock.web.MockFilterConfig; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.util.Date; + +import javax.servlet.FilterChain; + +import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.core.session.SessionRegistryImpl; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.authentication.logout.LogoutHandler; +import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler; import org.springframework.security.web.session.ConcurrentSessionFilter; -import javax.servlet.Filter; -import javax.servlet.FilterChain; -import javax.servlet.FilterConfig; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import java.io.IOException; -import java.util.Date; - /** * Tests {@link ConcurrentSessionFilter}. * * @author Ben Alex - * @version $Id$ + * @author Luke Taylor */ -public class ConcurrentSessionFilterTests extends TestCase { +public class ConcurrentSessionFilterTests { - //~ Methods ======================================================================================================== - - private void executeFilterInContainerSimulator(FilterConfig filterConfig, Filter filter, ServletRequest request, - ServletResponse response, FilterChain filterChain) - throws ServletException, IOException { - filter.init(filterConfig); - filter.doFilter(request, response, filterChain); - filter.destroy(); - } - - public void testDetectsExpiredSessions() throws Exception { + @Test + public void detectsExpiredSessions() throws Exception { // Setup our HTTP request MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); - MockFilterConfig config = new MockFilterConfig(null, null); - - // Setup our expectation that the filter chain will not be invoked, as we redirect to expiredUrl - MockFilterChain chain = new MockFilterChain(false); // Setup our test fixture and registry to want this session to be expired ConcurrentSessionFilter filter = new ConcurrentSessionFilter(); + filter.setRedirectStrategy(new DefaultRedirectStrategy()); + filter.setLogoutHandlers(new LogoutHandler[] {new SecurityContextLogoutHandler()}); + SessionRegistry registry = new SessionRegistryImpl(); registry.registerNewSession(session.getId(), "principal"); registry.getSessionInformation(session.getId()).expireNow(); filter.setSessionRegistry(registry); filter.setExpiredUrl("/expired.jsp"); + filter.afterPropertiesSet(); - // Test - executeFilterInContainerSimulator(config, filter, request, response, chain); + FilterChain fc = mock(FilterChain.class); + filter.doFilter(request, response, fc); + // Expect that the filter chain will not be invoked, as we redirect to expiredUrl + verifyZeroInteractions(fc); assertEquals("/expired.jsp", response.getRedirectedUrl()); } // As above, but with no expiredUrl set. - public void testReturnsExpectedMessageWhenNoExpiredUrlSet() throws Exception { + @Test + public void returnsExpectedMessageWhenNoExpiredUrlSet() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); - MockFilterConfig config = new MockFilterConfig(null, null); - - MockFilterChain chain = new MockFilterChain(false); ConcurrentSessionFilter filter = new ConcurrentSessionFilter(); SessionRegistry registry = new SessionRegistryImpl(); @@ -95,35 +86,36 @@ public class ConcurrentSessionFilterTests extends TestCase { registry.getSessionInformation(session.getId()).expireNow(); filter.setSessionRegistry(registry); - executeFilterInContainerSimulator(config, filter, request, response, chain); + FilterChain fc = mock(FilterChain.class); + filter.doFilter(request, response, fc); + verifyZeroInteractions(fc); assertEquals("This session has been expired (possibly due to multiple concurrent logins being " + "attempted as the same user).", response.getContentAsString()); } - public void testDetectsMissingSessionRegistry() throws Exception { + @Test(expected=IllegalArgumentException.class) + public void detectsMissingSessionRegistry() throws Exception { ConcurrentSessionFilter filter = new ConcurrentSessionFilter(); - filter.setExpiredUrl("xcx"); - - try { - filter.afterPropertiesSet(); - fail("Should have thrown IAE"); - } catch (IllegalArgumentException expected) { - assertTrue(true); - } + filter.afterPropertiesSet(); } - public void testUpdatesLastRequestTime() throws Exception { + @Test(expected=IllegalArgumentException.class) + public void detectsInvalidUrl() throws Exception { + ConcurrentSessionFilter filter = new ConcurrentSessionFilter(); + filter.setExpiredUrl("ImNotValid"); + filter.afterPropertiesSet(); + } + + @Test + public void lastRequestTimeUpdatesCorrectly() throws Exception { // Setup our HTTP request MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = new MockHttpServletResponse(); - MockFilterConfig config = new MockFilterConfig(null, null); - - // Setup our expectation that the filter chain will be invoked, as our session hasn't expired - MockFilterChain chain = new MockFilterChain(true); + FilterChain fc = mock(FilterChain.class); // Setup our test fixture ConcurrentSessionFilter filter = new ConcurrentSessionFilter(); @@ -136,28 +128,9 @@ public class ConcurrentSessionFilterTests extends TestCase { Thread.sleep(1000); - // Test - executeFilterInContainerSimulator(config, filter, request, response, chain); + filter.doFilter(request, response, fc); + verify(fc).doFilter(request, response); assertTrue(registry.getSessionInformation(session.getId()).getLastRequest().after(lastRequest)); } - - //~ Inner Classes ================================================================================================== - - private class MockFilterChain implements FilterChain { - private boolean expectToProceed; - - public MockFilterChain(boolean expectToProceed) { - this.expectToProceed = expectToProceed; - } - - public void doFilter(ServletRequest request, ServletResponse response) - throws IOException, ServletException { - if (expectToProceed) { - assertTrue(true); - } else { - fail("Did not expect filter chain to proceed"); - } - } - } } diff --git a/web/src/test/java/org/springframework/security/web/session/SessionManagementFilterTests.java b/web/src/test/java/org/springframework/security/web/session/SessionManagementFilterTests.java index 2d21e95397..d1e7c5943c 100644 --- a/web/src/test/java/org/springframework/security/web/session/SessionManagementFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/session/SessionManagementFilterTests.java @@ -4,6 +4,7 @@ import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; +import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -15,13 +16,14 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.context.SecurityContextRepository; /** - * * @author Luke Taylor - * @version $Id$ */ public class SessionManagementFilterTests { @@ -89,6 +91,29 @@ public class SessionManagementFilterTests { verifyNoMoreInteractions(strategy); } + @Test + public void strategyFailureInvokesFailureHandler() throws Exception { + SecurityContextRepository repo = mock(SecurityContextRepository.class); + // repo will return false to containsContext() + SessionAuthenticationStrategy strategy = mock(SessionAuthenticationStrategy.class); + + AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class); + SessionManagementFilter filter = new SessionManagementFilter(repo); + filter.setAuthenticationFailureHandler(failureHandler); + filter.setSessionAuthenticationStrategy(strategy); + HttpServletRequest request = new MockHttpServletRequest(); + HttpServletResponse response = new MockHttpServletResponse(); + FilterChain fc = mock(FilterChain.class); + authenticateUser(); + SessionAuthenticationException exception = new SessionAuthenticationException("Failure"); + doThrow(exception).when(strategy).onAuthentication( + SecurityContextHolder.getContext().getAuthentication(), request, response); + + filter.doFilter(request,response, fc); + verifyZeroInteractions(fc); + verify(failureHandler).onAuthenticationFailure(request, response, exception); + } + @Test public void responseIsRedirectedToTimeoutUrlIfSetAndSessionIsInvalid() throws Exception { SecurityContextRepository repo = mock(SecurityContextRepository.class); @@ -96,6 +121,7 @@ public class SessionManagementFilterTests { SessionAuthenticationStrategy strategy = mock(SessionAuthenticationStrategy.class); SessionManagementFilter filter = new SessionManagementFilter(repo); filter.setSessionAuthenticationStrategy(strategy); + filter.setRedirectStrategy(new DefaultRedirectStrategy()); MockHttpServletRequest request = new MockHttpServletRequest(); request.setRequestedSessionId("xxx"); request.setRequestedSessionIdValid(false); @@ -109,7 +135,10 @@ public class SessionManagementFilterTests { request.setRequestedSessionId("xxx"); request.setRequestedSessionIdValid(false); filter.setInvalidSessionUrl("/timedOut"); - filter.doFilter(request, response, new MockFilterChain()); + FilterChain fc = mock(FilterChain.class); + filter.doFilter(request, response, fc); + verifyZeroInteractions(fc); + assertEquals("/timedOut", response.getRedirectedUrl()); }