Constant Time Comparison for CSRF tokens

Closes gh-9291
This commit is contained in:
Rob Winch 2020-12-17 15:01:28 -06:00
parent c066e23a86
commit 40e027c56d
2 changed files with 42 additions and 2 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.security.web.csrf;
import java.io.IOException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.HashSet;
@ -31,6 +32,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.core.log.LogMessage;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
import org.springframework.security.web.util.UrlUtils;
@ -119,7 +121,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
if (actualToken == null) {
actualToken = request.getParameter(csrfToken.getParameterName());
}
if (!csrfToken.getToken().equals(actualToken)) {
if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
this.logger.debug(
LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
@ -165,6 +167,24 @@ public final class CsrfFilter extends OncePerRequestFilter {
this.accessDeniedHandler = accessDeniedHandler;
}
/**
* Constant time comparison to prevent against timing attacks.
* @param expected
* @param actual
* @return
*/
private static boolean equalsConstantTime(String expected, String actual) {
byte[] expectedBytes = bytesUtf8(expected);
byte[] actualBytes = bytesUtf8(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
}
private static byte[] bytesUtf8(String s) {
// need to check if Utf8.encode() runs in constant time (probably not).
// This may leak length of string.
return (s != null) ? Utf8.encode(s) : null;
}
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
private final HashSet<String> allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));

View File

@ -16,6 +16,7 @@
package org.springframework.security.web.server.csrf;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
@ -28,6 +29,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
@ -139,7 +141,7 @@ public class CsrfWebFilter implements WebFilter {
return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
.map((actual) -> actual.equals(expected.getToken()));
.map((actual) -> equalsConstantTime(actual, expected.getToken()));
}
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
@ -168,6 +170,24 @@ public class CsrfWebFilter implements WebFilter {
return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
}
/**
* Constant time comparison to prevent against timing attacks.
* @param expected
* @param actual
* @return
*/
private static boolean equalsConstantTime(String expected, String actual) {
byte[] expectedBytes = bytesUtf8(expected);
byte[] actualBytes = bytesUtf8(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
}
private static byte[] bytesUtf8(String s) {
// need to check if Utf8.encode() runs in constant time (probably not).
// This may leak length of string.
return (s != null) ? Utf8.encode(s) : null;
}
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
return this.csrfTokenRepository.generateToken(exchange)
.delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));