diff --git a/core/src/main/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilter.java b/core/src/main/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilter.java index f93815a5f3..e2e840f1ef 100644 --- a/core/src/main/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilter.java +++ b/core/src/main/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilter.java @@ -89,6 +89,8 @@ public class RememberMeProcessingFilter extends SpringSecurityFilter implements // Store to SecurityContextHolder SecurityContextHolder.getContext().setAuthentication(rememberMeAuth); + onSuccessfulAuthentication(request, response, rememberMeAuth); + if (logger.isDebugEnabled()) { logger.debug("SecurityContextHolder populated with remember-me token: '" + SecurityContextHolder.getContext().getAuthentication() + "'"); @@ -107,6 +109,8 @@ public class RememberMeProcessingFilter extends SpringSecurityFilter implements } rememberMeServices.loginFail(request, response); + + onUnsuccessfulAuthentication(request, response, authenticationException); } } @@ -121,6 +125,23 @@ public class RememberMeProcessingFilter extends SpringSecurityFilter implements } } + /** + * Called if a remember-me token is presented and successfully authenticated by the RememberMeServices + * autoLogin method and the AuthenticationManager. + */ + protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, + Authentication authResult) { + } + + /** + * Called if the AuthenticationManager rejects the authentication object returned from the + * RememberMeServices autoLogin method. This method will not be called when no remember-me + * token is present in the request and autoLogin returns null. + */ + protected void onUnsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, + AuthenticationException failed) { + } + public RememberMeServices getRememberMeServices() { return rememberMeServices; } diff --git a/core/src/test/java/org/springframework/security/MockApplicationEventPublisher.java b/core/src/test/java/org/springframework/security/MockApplicationEventPublisher.java new file mode 100644 index 0000000000..77e3651c87 --- /dev/null +++ b/core/src/test/java/org/springframework/security/MockApplicationEventPublisher.java @@ -0,0 +1,32 @@ +package org.springframework.security; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEvent; + +/** + * @author Luke Taylor + * @version $Id$ + */ +public class MockApplicationEventPublisher implements ApplicationEventPublisher { + private Boolean expectedEvent; + private ApplicationEvent lastEvent; + + public MockApplicationEventPublisher() { + } + + public MockApplicationEventPublisher(boolean expectedEvent) { + this.expectedEvent = Boolean.valueOf(expectedEvent); + } + + public void publishEvent(ApplicationEvent event) { + if (expectedEvent != null && !expectedEvent.booleanValue()) { + throw new IllegalStateException("The ApplicationEventPublisher did not expect to receive this event"); + } + + lastEvent = event; + } + + public ApplicationEvent getLastEvent() { + return lastEvent; + } +} diff --git a/core/src/test/java/org/springframework/security/providers/ProviderManagerTests.java b/core/src/test/java/org/springframework/security/providers/ProviderManagerTests.java index 0a979a73de..2a21fe5c63 100644 --- a/core/src/test/java/org/springframework/security/providers/ProviderManagerTests.java +++ b/core/src/test/java/org/springframework/security/providers/ProviderManagerTests.java @@ -20,10 +20,9 @@ import org.springframework.security.AuthenticationException; import org.springframework.security.AuthenticationServiceException; import org.springframework.security.GrantedAuthority; import org.springframework.security.GrantedAuthorityImpl; +import org.springframework.security.MockApplicationEventPublisher; import org.springframework.security.concurrent.ConcurrentSessionControllerImpl; import org.springframework.security.concurrent.NullConcurrentSessionController; -import org.springframework.context.ApplicationEvent; -import org.springframework.context.ApplicationEventPublisher; import junit.framework.TestCase; @@ -228,20 +227,6 @@ public class ProviderManagerTests extends TestCase { //~ Inner Classes ================================================================================================== - private class MockApplicationEventPublisher implements ApplicationEventPublisher { - private boolean expectedEvent; - - public MockApplicationEventPublisher(boolean expectedEvent) { - this.expectedEvent = expectedEvent; - } - - public void publishEvent(ApplicationEvent event) { - if (!expectedEvent) { - throw new IllegalStateException("The ApplicationEventPublisher did not expect to receive this event"); - } - } - } - private class MockProvider implements AuthenticationProvider { public Authentication authenticate(Authentication authentication) throws AuthenticationException { if (supports(authentication.getClass())) { diff --git a/core/src/test/java/org/springframework/security/ui/basicauth/BasicProcessingFilterTests.java b/core/src/test/java/org/springframework/security/ui/basicauth/BasicProcessingFilterTests.java index 287f21ef59..9a55751959 100644 --- a/core/src/test/java/org/springframework/security/ui/basicauth/BasicProcessingFilterTests.java +++ b/core/src/test/java/org/springframework/security/ui/basicauth/BasicProcessingFilterTests.java @@ -19,6 +19,7 @@ import org.springframework.security.MockAuthenticationEntryPoint; import org.springframework.security.MockAuthenticationManager; import org.springframework.security.MockFilterChain; import org.springframework.security.MockFilterConfig; +import org.springframework.security.MockApplicationEventPublisher; import org.springframework.security.context.SecurityContextHolder; @@ -35,9 +36,6 @@ import org.apache.commons.codec.binary.Base64; import org.jmock.Mock; import org.jmock.MockObjectTestCase; -import org.springframework.context.ApplicationEvent; -import org.springframework.context.ApplicationEventPublisher; - import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; @@ -66,7 +64,6 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { //~ Constructors =================================================================================================== public BasicProcessingFilterTests() { - super(); } public BasicProcessingFilterTests(String arg0) { @@ -91,10 +88,6 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { return response; } - public static void main(String[] args) { - junit.textui.TestRunner.run(BasicProcessingFilterTests.class); - } - protected void setUp() throws Exception { super.setUp(); SecurityContextHolder.clearContext(); @@ -123,8 +116,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { SecurityContextHolder.clearContext(); } - public void testDoFilterWithNonHttpServletRequestDetected() - throws Exception { + public void testDoFilterWithNonHttpServletRequestDetected() throws Exception { BasicProcessingFilter filter = new BasicProcessingFilter(); try { @@ -135,8 +127,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { } } - public void testDoFilterWithNonHttpServletResponseDetected() - throws Exception { + public void testDoFilterWithNonHttpServletResponseDetected() throws Exception { BasicProcessingFilter filter = new BasicProcessingFilter(); try { @@ -147,8 +138,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { } } - public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() - throws Exception { + public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception { // Setup our HTTP request MockHttpServletRequest request = new MockHttpServletRequest(); request.setServletPath("/some_file.html"); @@ -168,8 +158,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { assertTrue(filter.getAuthenticationEntryPoint() != null); } - public void testInvalidBasicAuthorizationTokenIsIgnored() - throws Exception { + public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception { // Setup our HTTP request String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON"; MockHttpServletRequest request = new MockHttpServletRequest(); @@ -200,8 +189,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { ((UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal()).getUsername()); } - public void testOtherAuthorizationSchemeIsIgnored() - throws Exception { + public void testOtherAuthorizationSchemeIsIgnored() throws Exception { // Setup our HTTP request MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME"); @@ -213,8 +201,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { assertNull(SecurityContextHolder.getContext().getAuthentication()); } - public void testStartupDetectsMissingAuthenticationEntryPoint() - throws Exception { + public void testStartupDetectsMissingAuthenticationEntryPoint() throws Exception { try { BasicProcessingFilter filter = new BasicProcessingFilter(); filter.setAuthenticationManager(new MockAuthenticationManager()); @@ -225,8 +212,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { } } - public void testStartupDetectsMissingAuthenticationManager() - throws Exception { + public void testStartupDetectsMissingAuthenticationManager() throws Exception { try { BasicProcessingFilter filter = new BasicProcessingFilter(); filter.setAuthenticationEntryPoint(new MockAuthenticationEntryPoint("x")); @@ -237,8 +223,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { } } - public void testSuccessLoginThenFailureLoginResultsInSessionLoosingToken() - throws Exception { + public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception { // Setup our HTTP request String token = "rod:koala"; MockHttpServletRequest request = new MockHttpServletRequest(); @@ -268,8 +253,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { assertEquals(401, response.getStatus()); } - public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() - throws Exception { + public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception { // Setup our HTTP request String token = "rod:WRONG_PASSWORD"; MockHttpServletRequest request = new MockHttpServletRequest(); @@ -286,8 +270,7 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { assertNull(SecurityContextHolder.getContext().getAuthentication()); } - public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() - throws Exception { + public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception { // Setup our HTTP request String token = "rod:WRONG_PASSWORD"; MockHttpServletRequest request = new MockHttpServletRequest(); @@ -302,12 +285,4 @@ public class BasicProcessingFilterTests extends MockObjectTestCase { assertNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals(401, response.getStatus()); } - - //~ Inner Classes ================================================================================================== - - private class MockApplicationEventPublisher implements ApplicationEventPublisher { - public MockApplicationEventPublisher() {} - - public void publishEvent(ApplicationEvent event) {} - } } diff --git a/core/src/test/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilterTests.java b/core/src/test/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilterTests.java index ef95a6fcc9..24eb94f86f 100644 --- a/core/src/test/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilterTests.java +++ b/core/src/test/java/org/springframework/security/ui/rememberme/RememberMeProcessingFilterTests.java @@ -20,6 +20,8 @@ import org.springframework.security.GrantedAuthority; import org.springframework.security.GrantedAuthorityImpl; import org.springframework.security.MockAuthenticationManager; import org.springframework.security.MockFilterConfig; +import org.springframework.security.AuthenticationException; +import org.springframework.security.MockApplicationEventPublisher; import org.springframework.security.context.SecurityContextHolder; import org.springframework.security.providers.TestingAuthenticationToken; import org.springframework.mock.web.MockHttpServletRequest; @@ -48,7 +50,6 @@ public class RememberMeProcessingFilterTests extends TestCase { //~ Constructors =================================================================================================== public RememberMeProcessingFilterTests() { - super(); } public RememberMeProcessingFilterTests(String arg0) { @@ -115,36 +116,7 @@ public class RememberMeProcessingFilterTests extends TestCase { } } - public void testDoFilterWithNonHttpServletRequestDetected() - throws Exception { - RememberMeProcessingFilter filter = new RememberMeProcessingFilter(); - filter.setAuthenticationManager(new MockAuthenticationManager()); - - try { - filter.doFilter(null, new MockHttpServletResponse(), new MockFilterChain()); - fail("Should have thrown ServletException"); - } catch (ServletException expected) { - assertEquals("Can only process HttpServletRequest", expected.getMessage()); - } - } - - public void testDoFilterWithNonHttpServletResponseDetected() - throws Exception { - RememberMeProcessingFilter filter = new RememberMeProcessingFilter(); - filter.setAuthenticationManager(new MockAuthenticationManager()); - - try { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setRequestURI("dc"); - filter.doFilter(request, null, new MockFilterChain()); - fail("Should have thrown ServletException"); - } catch (ServletException expected) { - assertEquals("Can only process HttpServletResponse", expected.getMessage()); - } - } - - public void testOperationWhenAuthenticationExistsInContextHolder() - throws Exception { + public void testOperationWhenAuthenticationExistsInContextHolder() throws Exception { // Put an Authentication object into the SecurityContextHolder Authentication originalAuth = new TestingAuthenticationToken("user", "password", new GrantedAuthority[] {new GrantedAuthorityImpl("ROLE_A")}); @@ -168,8 +140,7 @@ public class RememberMeProcessingFilterTests extends TestCase { assertEquals(originalAuth, SecurityContextHolder.getContext().getAuthentication()); } - public void testOperationWhenNoAuthenticationInContextHolder() - throws Exception { + public void testOperationWhenNoAuthenticationInContextHolder() throws Exception { Authentication remembered = new TestingAuthenticationToken("remembered", "password", new GrantedAuthority[] {new GrantedAuthorityImpl("ROLE_REMEMBERED")}); RememberMeProcessingFilter filter = new RememberMeProcessingFilter(); @@ -186,6 +157,30 @@ public class RememberMeProcessingFilterTests extends TestCase { assertEquals(remembered, SecurityContextHolder.getContext().getAuthentication()); } + public void testOnunsuccessfulLoginIsCalledWhenProviderRejectsAuth() throws Exception { + Authentication remembered = new TestingAuthenticationToken("remembered", "password", + new GrantedAuthority[] {new GrantedAuthorityImpl("ROLE_REMEMBERED")}); + final Authentication failedAuth = new TestingAuthenticationToken("failed", "", null); + + RememberMeProcessingFilter filter = new RememberMeProcessingFilter() { + protected void onUnsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) { + super.onUnsuccessfulAuthentication(request, response, failed); + SecurityContextHolder.getContext().setAuthentication(failedAuth); + } + }; + filter.setAuthenticationManager(new MockAuthenticationManager(false)); + filter.setRememberMeServices(new MockRememberMeServices(remembered)); + filter.setApplicationEventPublisher(new MockApplicationEventPublisher()); + filter.afterPropertiesSet(); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setRequestURI("x"); + executeFilterInContainerSimulator(new MockFilterConfig(), filter, request, new MockHttpServletResponse(), + new MockFilterChain(true)); + + assertEquals(failedAuth, SecurityContextHolder.getContext().getAuthentication()); + } + //~ Inner Classes ================================================================================================== private class MockFilterChain implements FilterChain {