diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index f508e51599..af6ac23428 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -51,6 +51,11 @@ public final class CsrfAuthenticationStrategy implements public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { - this.csrfTokenRepository.saveToken(null, request, response); + boolean containsToken = this.csrfTokenRepository.loadToken(request) != null; + if(containsToken) { + CsrfToken newToken = this.csrfTokenRepository.generateToken(request); + this.csrfTokenRepository.saveToken(null, request, response); + this.csrfTokenRepository.saveToken(newToken, request, response); + } } } \ No newline at end of file diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index 77aad4ccdc..866374142f 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -15,7 +15,14 @@ */ package org.springframework.security.web.csrf; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.junit.Before; import org.junit.Test; @@ -41,11 +48,17 @@ public class CsrfAuthenticationStrategyTests { private CsrfAuthenticationStrategy strategy; + private CsrfToken existingToken; + + private CsrfToken generatedToken; + @Before public void setup() { request = new MockHttpServletRequest(); response = new MockHttpServletResponse(); strategy = new CsrfAuthenticationStrategy(csrfTokenRepository); + existingToken = new DefaultCsrfToken("_csrf", "_csrf", "1"); + generatedToken = new DefaultCsrfToken("_csrf", "_csrf", "2"); } @Test(expected = IllegalArgumentException.class) @@ -54,11 +67,21 @@ public class CsrfAuthenticationStrategyTests { } @Test - public void logoutRemovesCsrfToken() { - strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"),request, response); + public void logoutRemovesCsrfTokenAndSavesNew() { + when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken); + when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken); + strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), request, response); verify(csrfTokenRepository).saveToken(null, request, response); + // SEC-2404 + verify(csrfTokenRepository).saveToken(eq(generatedToken), eq(request), eq(response)); } + @Test + public void logoutRemovesNoActionIfNullToken() { + strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), request, response); + + verify(csrfTokenRepository,never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + } }