diff --git a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java index a47e010279..e3f758721e 100755 --- a/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java @@ -33,7 +33,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.WebAttributes; -import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.authentication.*; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; @@ -84,6 +84,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi private boolean continueFilterChainOnUnsuccessfulAuthentication = true; private boolean checkForPrincipalChanges; private boolean invalidateSessionOnPrincipalChange = true; + private AuthenticationSuccessHandler authenticationSuccessHandler = null; + private AuthenticationFailureHandler authenticationFailureHandler = null; /** * Check whether all required properties have been set. @@ -156,7 +158,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi /** * Do the actual authentication for a pre-authenticated user. */ - private void doAuthenticate(HttpServletRequest request, HttpServletResponse response) { + private void doAuthenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { Authentication authResult; Object principal = getPreAuthenticatedPrincipal(request); @@ -229,7 +231,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi * manager into the secure context. */ protected void successfulAuthentication(HttpServletRequest request, - HttpServletResponse response, Authentication authResult) { + HttpServletResponse response, Authentication authResult) throws IOException, ServletException { if (logger.isDebugEnabled()) { logger.debug("Authentication success: " + authResult); } @@ -239,6 +241,10 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( authResult, this.getClass())); } + + if(authenticationSuccessHandler != null) { + authenticationSuccessHandler.onAuthenticationSuccess(request, response, authResult); + } } /** @@ -248,13 +254,17 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi * Caches the failure exception as a request attribute */ protected void unsuccessfulAuthentication(HttpServletRequest request, - HttpServletResponse response, AuthenticationException failed) { + HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException { SecurityContextHolder.clearContext(); if (logger.isDebugEnabled()) { logger.debug("Cleared security context due to exception", failed); } request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, failed); + + if(authenticationFailureHandler != null) { + authenticationFailureHandler.onAuthenticationFailure(request, response, failed); + } } /** @@ -324,6 +334,20 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi this.invalidateSessionOnPrincipalChange = invalidateSessionOnPrincipalChange; } + /** + * Sets the strategy used to handle a successful authentication. + */ + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Sets the strategy used to handle a failed authentication. + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + this.authenticationFailureHandler = authenticationFailureHandler; + } + /** * Override to extract the principal information from the current request */ diff --git a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java index a3b55b14b5..6969f62002 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java @@ -40,9 +40,13 @@ import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; +import org.springframework.security.web.WebAttributes; +import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler; +import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler; /** * @@ -206,6 +210,51 @@ public class AbstractPreAuthenticatedProcessingFilterTests { verify(am).authenticate(any(PreAuthenticatedAuthenticationToken.class)); } + @Test + public void callsAuthenticationSuccessHandlerOnSuccessfulAuthentication() throws Exception { + Object currentPrincipal = "currentUser"; + TestingAuthenticationToken authRequest = new TestingAuthenticationToken( + currentPrincipal, "something", "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(authRequest); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain chain = new MockFilterChain(); + + ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter(); + filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl")); + filter.setCheckForPrincipalChanges(true); + filter.principal = "newUser"; + AuthenticationManager am = mock(AuthenticationManager.class); + filter.setAuthenticationManager(am); + filter.afterPropertiesSet(); + + filter.doFilter(request, response, chain); + + verify(am).authenticate(any(PreAuthenticatedAuthenticationToken.class)); + assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl"); + } + + @Test + public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain chain = new MockFilterChain(); + + ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter(); + filter.setAuthenticationFailureHandler(new ForwardAuthenticationFailureHandler("/forwardUrl")); + filter.setCheckForPrincipalChanges(true); + AuthenticationManager am = mock(AuthenticationManager.class); + when(am.authenticate(any(PreAuthenticatedAuthenticationToken.class))).thenThrow(new PreAuthenticatedCredentialsNotFoundException("invalid")); + filter.setAuthenticationManager(am); + filter.afterPropertiesSet(); + + filter.doFilter(request, response, chain); + + verify(am).authenticate(any(PreAuthenticatedAuthenticationToken.class)); + assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl"); + assertThat(request.getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION)).isNotNull(); + } + // SEC-2078 @Test public void requiresAuthenticationFalsePrincipalNotString() throws Exception {