SEC-1603: Add support for injecting an AuthenticationSuccessHandler into RememberMeAuthenticationFilter.

This commit is contained in:
Luke Taylor 2011-01-06 13:02:38 +00:00
parent c1f2fa1983
commit 7fd3aa2b45
2 changed files with 100 additions and 81 deletions

View File

@ -31,32 +31,39 @@ import org.springframework.security.authentication.event.InteractiveAuthenticati
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean; import org.springframework.web.filter.GenericFilterBean;
/** /**
* Detects if there is no <code>Authentication</code> object in the <code>SecurityContext</code>, and populates it * Detects if there is no {@code Authentication} object in the {@code SecurityContext}, and populates the context with
* with a remember-me authentication token if a {@link org.springframework.security.web.authentication.RememberMeServices} * a remember-me authentication token if a {@link RememberMeServices} implementation so requests.
* implementation so requests.<p>Concrete <code>RememberMeServices</code> implementations will have their {@link * <p>
* org.springframework.security.web.authentication.RememberMeServices#autoLogin(HttpServletRequest, HttpServletResponse)} method * Concrete {@code RememberMeServices} implementations will have their
* called by this filter. The <code>Authentication</code> or <code>null</code> returned by that method will be placed * {@link RememberMeServices#autoLogin(HttpServletRequest, HttpServletResponse)}
* into the <code>SecurityContext</code>. The <code>AuthenticationManager</code> will be used, so that any concurrent * method called by this filter. If this method returns a non-null {@code Authentication} object, it will be passed
* session management or other authentication-specific behaviour can be achieved. This is the same pattern as with * to the {@code AuthenticationManager}, so that any authentication-specific behaviour can be achieved.
* other authentication mechanisms, which call the <code>AuthenticationManager</code> as part of their contract.</p> * The resulting {@code Authentication} (if successful) will be placed into the {@code SecurityContext}.
* <p>If authentication is successful, an {@link * <p>
* org.springframework.security.authentication.event.InteractiveAuthenticationSuccessEvent} will be published to the application * If authentication is successful, an {@link InteractiveAuthenticationSuccessEvent} will be published
* context. No events will be published if authentication was unsuccessful, because this would generally be recorded * to the application context. No events will be published if authentication was unsuccessful, because this would
* via an <code>AuthenticationManager</code>-specific application event.</p> * generally be recorded via an {@code AuthenticationManager}-specific application event.
* <p>
* Normally the request will be allowed to proceed regardless of whether authentication succeeds or fails. If
* some control over the destination for authenticated users is required, an {@link AuthenticationSuccessHandler}
* can be injected
* *
* @author Ben Alex * @author Ben Alex
* @author Luke Taylor
*/ */
public class RememberMeAuthenticationFilter extends GenericFilterBean implements ApplicationEventPublisherAware { public class RememberMeAuthenticationFilter extends GenericFilterBean implements ApplicationEventPublisherAware {
//~ Instance fields ================================================================================================ //~ Instance fields ================================================================================================
private ApplicationEventPublisher eventPublisher; private ApplicationEventPublisher eventPublisher;
private AuthenticationSuccessHandler successHandler;
private AuthenticationManager authenticationManager; private AuthenticationManager authenticationManager;
private RememberMeServices rememberMeServices; private RememberMeServices rememberMeServices;
@ -96,6 +103,13 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
SecurityContextHolder.getContext().getAuthentication(), this.getClass())); SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
} }
if (successHandler != null) {
successHandler.onAuthenticationSuccess(request, response, rememberMeAuth);
return;
}
} catch (AuthenticationException authenticationException) { } catch (AuthenticationException authenticationException) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("SecurityContextHolder not populated with remember-me token, as " logger.debug("SecurityContextHolder not populated with remember-me token, as "
@ -121,17 +135,17 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
} }
/** /**
* Called if a remember-me token is presented and successfully authenticated by the <tt>RememberMeServices</tt> * Called if a remember-me token is presented and successfully authenticated by the {@code RememberMeServices}
* <tt>autoLogin</tt> method and the <tt>AuthenticationManager</tt>. * {@code autoLogin} method and the {@code AuthenticationManager}.
*/ */
protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
Authentication authResult) { Authentication authResult) {
} }
/** /**
* Called if the <tt>AuthenticationManager</tt> rejects the authentication object returned from the * Called if the {@code AuthenticationManager} rejects the authentication object returned from the
* <tt>RememberMeServices</tt> <tt>autoLogin</tt> method. This method will not be called when no remember-me * {@code RememberMeServices} {@code autoLogin} method. This method will not be called when no remember-me
* token is present in the request and <tt>autoLogin</tt> returns null. * token is present in the request and {@code autoLogin} reurns null.
*/ */
protected void onUnsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, protected void onUnsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
AuthenticationException failed) { AuthenticationException failed) {
@ -152,4 +166,19 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
public void setRememberMeServices(RememberMeServices rememberMeServices) { public void setRememberMeServices(RememberMeServices rememberMeServices) {
this.rememberMeServices = rememberMeServices; this.rememberMeServices = rememberMeServices;
} }
/**
* Allows control over the destination a remembered user is sent to when they are successfully authenticated.
* By default, the filter will just allow the current request to proceed, but if an
* {@code AuthenticationSuccessHandler} is set, it will be invoked and the {@code doFilter()} method will return
* immediately, thus allowing the application to redirect the user to a specific URL, regardless of whatthe original
* request was for.
*
* @param successHandler the strategy to invoke immediately before returning from {@code doFilter()}.
*/
public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler successHandler) {
Assert.notNull(successHandler, "successHandler cannot be null");
this.successHandler = successHandler;
}
} }

View File

@ -15,22 +15,11 @@
package org.springframework.security.web.authentication.rememberme; package org.springframework.security.web.authentication.rememberme;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
import java.io.IOException; import org.junit.*;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import junit.framework.TestCase;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
@ -42,6 +31,11 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.NullRememberMeServices;
import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.RememberMeServices;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/** /**
@ -49,27 +43,23 @@ import org.springframework.security.web.authentication.RememberMeServices;
* *
* @author Ben Alex * @author Ben Alex
*/ */
public class RememberMeAuthenticationFilterTests extends TestCase { public class RememberMeAuthenticationFilterTests {
Authentication remembered = new TestingAuthenticationToken("remembered", "password","ROLE_REMEMBERED"); Authentication remembered = new TestingAuthenticationToken("remembered", "password","ROLE_REMEMBERED");
//~ Methods ======================================================================================================== //~ Methods ========================================================================================================
private void executeFilterInContainerSimulator(FilterConfig filterConfig, Filter filter, ServletRequest request, @Before
ServletResponse response, FilterChain filterChain) throws ServletException, IOException { public void setUp() {
// filter.init(filterConfig);
filter.doFilter(request, response, filterChain);
// filter.destroy();
}
protected void setUp() throws Exception {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
} }
protected void tearDown() throws Exception { @After
public void tearDown() {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
} }
public void testDetectsAuthenticationManagerProperty() throws Exception { @Test(expected = IllegalArgumentException.class)
public void testDetectsAuthenticationManagerProperty() {
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(); RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter();
filter.setAuthenticationManager(mock(AuthenticationManager.class)); filter.setAuthenticationManager(mock(AuthenticationManager.class));
filter.setRememberMeServices(new NullRememberMeServices()); filter.setRememberMeServices(new NullRememberMeServices());
@ -78,15 +68,11 @@ public class RememberMeAuthenticationFilterTests extends TestCase {
filter.setAuthenticationManager(null); filter.setAuthenticationManager(null);
try { filter.afterPropertiesSet();
filter.afterPropertiesSet();
fail("Should have thrown IllegalArgumentException");
} catch (IllegalArgumentException expected) {
assertTrue(true);
}
} }
public void testDetectsRememberMeServicesProperty() throws Exception { @Test(expected = IllegalArgumentException.class)
public void testDetectsRememberMeServicesProperty() {
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(); RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter();
filter.setAuthenticationManager(mock(AuthenticationManager.class)); filter.setAuthenticationManager(mock(AuthenticationManager.class));
@ -100,14 +86,10 @@ public class RememberMeAuthenticationFilterTests extends TestCase {
// check detects if made null // check detects if made null
filter.setRememberMeServices(null); filter.setRememberMeServices(null);
try { filter.afterPropertiesSet();
filter.afterPropertiesSet();
fail("Should have thrown IllegalArgumentException");
} catch (IllegalArgumentException expected) {
assertTrue(true);
}
} }
@Test
public void testOperationWhenAuthenticationExistsInContextHolder() throws Exception { public void testOperationWhenAuthenticationExistsInContextHolder() throws Exception {
// Put an Authentication object into the SecurityContextHolder // Put an Authentication object into the SecurityContextHolder
Authentication originalAuth = new TestingAuthenticationToken("user", "password","ROLE_A"); Authentication originalAuth = new TestingAuthenticationToken("user", "password","ROLE_A");
@ -121,14 +103,16 @@ public class RememberMeAuthenticationFilterTests extends TestCase {
// Test // Test
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
FilterChain fc = mock(FilterChain.class);
request.setRequestURI("x"); request.setRequestURI("x");
executeFilterInContainerSimulator(mock(FilterConfig.class), filter, request, new MockHttpServletResponse(), filter.doFilter(request, new MockHttpServletResponse(), fc);
new MockFilterChain(true));
// Ensure filter didn't change our original object // Ensure filter didn't change our original object
assertEquals(originalAuth, SecurityContextHolder.getContext().getAuthentication()); assertSame(originalAuth, SecurityContextHolder.getContext().getAuthentication());
verify(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test
public void testOperationWhenNoAuthenticationInContextHolder() throws Exception { public void testOperationWhenNoAuthenticationInContextHolder() throws Exception {
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(); RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter();
@ -139,15 +123,17 @@ public class RememberMeAuthenticationFilterTests extends TestCase {
filter.afterPropertiesSet(); filter.afterPropertiesSet();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
FilterChain fc = mock(FilterChain.class);
request.setRequestURI("x"); request.setRequestURI("x");
executeFilterInContainerSimulator(mock(FilterConfig.class), filter, request, new MockHttpServletResponse(), filter.doFilter(request, new MockHttpServletResponse(), fc);
new MockFilterChain(true));
// Ensure filter setup with our remembered authentication object // Ensure filter setup with our remembered authentication object
assertEquals(remembered, SecurityContextHolder.getContext().getAuthentication()); assertSame(remembered, SecurityContextHolder.getContext().getAuthentication());
verify(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
public void testOnUnsuccessfulLoginIsCalledWhenProviderRejectsAuth() throws Exception { @Test
public void onUnsuccessfulLoginIsCalledWhenProviderRejectsAuth() throws Exception {
final Authentication failedAuth = new TestingAuthenticationToken("failed", ""); final Authentication failedAuth = new TestingAuthenticationToken("failed", "");
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter() { RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter() {
@ -164,32 +150,36 @@ public class RememberMeAuthenticationFilterTests extends TestCase {
filter.afterPropertiesSet(); filter.afterPropertiesSet();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
FilterChain fc = mock(FilterChain.class);
request.setRequestURI("x"); request.setRequestURI("x");
executeFilterInContainerSimulator(mock(FilterConfig.class), filter, request, new MockHttpServletResponse(), filter.doFilter(request, new MockHttpServletResponse(), fc);
new MockFilterChain(true));
assertEquals(failedAuth, SecurityContextHolder.getContext().getAuthentication()); assertSame(failedAuth, SecurityContextHolder.getContext().getAuthentication());
verify(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void authenticationSuccessHandlerIsInvokedOnSuccessfulAuthenticationIfSet() throws Exception {
RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter();
AuthenticationManager am = mock(AuthenticationManager.class);
when(am.authenticate(remembered)).thenReturn(remembered);
filter.setAuthenticationManager(am);
filter.setRememberMeServices(new MockRememberMeServices(remembered));
filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target"));
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain fc = mock(FilterChain.class);
request.setRequestURI("x");
filter.doFilter(request, response, fc);
assertEquals("/target", response.getRedirectedUrl());
// Should return after success handler is invoked, so chain should not proceed
verifyZeroInteractions(fc);
} }
//~ Inner Classes ================================================================================================== //~ Inner Classes ==================================================================================================
private class MockFilterChain implements FilterChain {
private boolean expectToProceed;
public MockFilterChain(boolean expectToProceed) {
this.expectToProceed = expectToProceed;
}
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
if (expectToProceed) {
assertTrue(true);
} else {
fail("Did not expect filter chain to proceed");
}
}
}
private class MockRememberMeServices implements RememberMeServices { private class MockRememberMeServices implements RememberMeServices {
private Authentication authToReturn; private Authentication authToReturn;