diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index bbbca0bfc0..362168f109 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -17,6 +17,7 @@ package org.springframework.security.web.csrf; import java.io.IOException; +import java.security.MessageDigest; import java.util.Arrays; import java.util.HashSet; @@ -31,6 +32,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.security.web.util.UrlUtils; @@ -119,7 +121,7 @@ public final class CsrfFilter extends OncePerRequestFilter { if (actualToken == null) { actualToken = request.getParameter(csrfToken.getParameterName()); } - if (!csrfToken.getToken().equals(actualToken)) { + if (!equalsConstantTime(csrfToken.getToken(), actualToken)) { this.logger.debug( LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request))); AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken) @@ -165,6 +167,24 @@ public final class CsrfFilter extends OncePerRequestFilter { this.accessDeniedHandler = accessDeniedHandler; } + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + byte[] expectedBytes = bytesUtf8(expected); + byte[] actualBytes = bytesUtf8(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + + private static byte[] bytesUtf8(String s) { + // need to check if Utf8.encode() runs in constant time (probably not). + // This may leak length of string. + return (s != null) ? Utf8.encode(s) : null; + } + private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { private final HashSet allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 46ffb2cafb..d789da4b56 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -16,6 +16,7 @@ package org.springframework.security.web.server.csrf; +import java.security.MessageDigest; import java.util.Arrays; import java.util.HashSet; import java.util.Set; @@ -28,6 +29,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.multipart.FormFieldPart; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; @@ -139,7 +141,7 @@ public class CsrfWebFilter implements WebFilter { return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName()))) .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName()))) .switchIfEmpty(tokenFromMultipartData(exchange, expected)) - .map((actual) -> actual.equals(expected.getToken())); + .map((actual) -> equalsConstantTime(actual, expected.getToken())); } private Mono tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) { @@ -168,6 +170,24 @@ public class CsrfWebFilter implements WebFilter { return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange)); } + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + byte[] expectedBytes = bytesUtf8(expected); + byte[] actualBytes = bytesUtf8(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + + private static byte[] bytesUtf8(String s) { + // need to check if Utf8.encode() runs in constant time (probably not). + // This may leak length of string. + return (s != null) ? Utf8.encode(s) : null; + } + private Mono generateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.generateToken(exchange) .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));