From 76d9ef4ec35be6ee595b9dbf9157b70261ea59ab Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 24 Feb 2015 17:29:55 -0600 Subject: [PATCH] SEC-2872: CsrfAuthenticationStrategy Delay Saving CsrfToken --- .../web/csrf/CsrfAuthenticationStrategy.java | 88 ++++++++++++++++++- .../csrf/CsrfAuthenticationStrategyTests.java | 16 +++- 2 files changed, 99 insertions(+), 5 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index a1d0170380..4240038f2e 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -53,11 +53,91 @@ public final class CsrfAuthenticationStrategy implements throws SessionAuthenticationException { boolean containsToken = this.csrfTokenRepository.loadToken(request) != null; if(containsToken) { - CsrfToken newToken = this.csrfTokenRepository.generateToken(request); this.csrfTokenRepository.saveToken(null, request, response); - this.csrfTokenRepository.saveToken(newToken, request, response); - request.setAttribute(CsrfToken.class.getName(), newToken); - request.setAttribute(newToken.getParameterName(), newToken); + + CsrfToken newToken = this.csrfTokenRepository.generateToken(request); + CsrfToken tokenForRequest = new SaveOnAccessCsrfToken(csrfTokenRepository, request, response, newToken); + + request.setAttribute(CsrfToken.class.getName(), tokenForRequest); + request.setAttribute(newToken.getParameterName(), tokenForRequest); } } + + private static final class SaveOnAccessCsrfToken implements CsrfToken { + private transient CsrfTokenRepository tokenRepository; + private transient HttpServletRequest request; + private transient HttpServletResponse response; + + private final CsrfToken delegate; + + public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository, + HttpServletRequest request, HttpServletResponse response, + CsrfToken delegate) { + super(); + this.tokenRepository = tokenRepository; + this.request = request; + this.response = response; + this.delegate = delegate; + } + + public String getHeaderName() { + return delegate.getHeaderName(); + } + + public String getParameterName() { + return delegate.getParameterName(); + } + + public String getToken() { + saveTokenIfNecessary(); + return delegate.getToken(); + } + + @Override + public String toString() { + return "SaveOnAccessCsrfToken [delegate=" + delegate + "]"; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + + ((delegate == null) ? 0 : delegate.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj; + if (delegate == null) { + if (other.delegate != null) + return false; + } else if (!delegate.equals(other.delegate)) + return false; + return true; + } + + private void saveTokenIfNecessary() { + if(this.tokenRepository == null) { + return; + } + + synchronized(this) { + if(tokenRepository != null) { + this.tokenRepository.saveToken(delegate, request, response); + this.tokenRepository = null; + this.request = null; + this.response = null; + } + } + } + + } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index 7bf105ad8a..8b087ed697 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -15,6 +15,7 @@ */ package org.springframework.security.web.csrf; +import static org.fest.assertions.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.never; @@ -73,7 +74,7 @@ public class CsrfAuthenticationStrategyTests { strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), request, response); verify(csrfTokenRepository).saveToken(null, request, response); - verify(csrfTokenRepository).saveToken(eq(generatedToken), eq(request), eq(response)); + verify(csrfTokenRepository,never()).saveToken(eq(generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); // SEC-2404, SEC-2832 CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); assertThat(tokenInRequest.getToken()).isSameAs(generatedToken.getToken()); @@ -82,6 +83,19 @@ public class CsrfAuthenticationStrategyTests { assertThat(request.getAttribute(generatedToken.getParameterName())).isSameAs(tokenInRequest); } + // SEC-2872 + @Test + public void delaySavingCsrf() { + when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken); + when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken); + strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), request, response); + + verify(csrfTokenRepository).saveToken(null, request, response); + verify(csrfTokenRepository,never()).saveToken(eq(generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); + + CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + tokenInRequest.getToken(); + verify(csrfTokenRepository).saveToken(eq(generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test