diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java index 0a77d570a2..2ab32598ce 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java @@ -143,6 +143,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi logger.debug("Pre-authenticated principal has changed to " + principal + " and will be reauthenticated"); if (invalidateSessionOnPrincipalChange) { + SecurityContextHolder.clearContext(); + HttpSession session = request.getSession(false); if (session != null) { diff --git a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java index a61f1436fa..cd4cdd907c 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java @@ -7,6 +7,7 @@ import static org.mockito.Mockito.*; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; @@ -16,6 +17,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -33,7 +35,12 @@ public class AbstractPreAuthenticatedProcessingFilterTests { return "doesntmatter"; } }; - SecurityContextHolder.getContext().setAuthentication(null); + SecurityContextHolder.clearContext(); + } + + @After + public void tearDown() { + SecurityContextHolder.clearContext(); } @Test @@ -80,6 +87,31 @@ public class AbstractPreAuthenticatedProcessingFilterTests { testDoFilter(false); } + // SEC-1968 + @Test + public void nullPreAuthenticationClearsPreviousUser() throws Exception { + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("oldUser", "pass","ROLE_USER")); + ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter(); + filter.principal = null; + filter.setCheckForPrincipalChanges(true); + + filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), new MockFilterChain()); + + assertEquals(null, SecurityContextHolder.getContext().getAuthentication()); + } + + @Test + public void nullPreAuthenticationPerservesPreviousUserCheckPrincipalChangesFalse() throws Exception { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("oldUser", "pass","ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(authentication); + ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter(); + filter.principal = null; + + filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), new MockFilterChain()); + + assertEquals(authentication, SecurityContextHolder.getContext().getAuthentication()); + } + private void testDoFilter(boolean grantAccess) throws Exception { MockHttpServletRequest req = new MockHttpServletRequest(); MockHttpServletResponse res = new MockHttpServletResponse(); @@ -107,8 +139,9 @@ public class AbstractPreAuthenticatedProcessingFilterTests { } private static class ConcretePreAuthenticatedProcessingFilter extends AbstractPreAuthenticatedProcessingFilter { + private String principal = "testPrincipal"; protected Object getPreAuthenticatedPrincipal(HttpServletRequest httpRequest) { - return "testPrincipal"; + return principal; } protected Object getPreAuthenticatedCredentials(HttpServletRequest httpRequest) { return "testCredentials";