diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 362168f109..8490c508da 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -174,17 +174,18 @@ public final class CsrfFilter extends OncePerRequestFilter { * @return */ private static boolean equalsConstantTime(String expected, String actual) { - byte[] expectedBytes = bytesUtf8(expected); - byte[] actualBytes = bytesUtf8(actual); + if (expected == actual) { + return true; + } + if (expected == null || actual == null) { + return false; + } + // Encode after ensure that the string is not null + byte[] expectedBytes = Utf8.encode(expected); + byte[] actualBytes = Utf8.encode(actual); return MessageDigest.isEqual(expectedBytes, actualBytes); } - private static byte[] bytesUtf8(String s) { - // need to check if Utf8.encode() runs in constant time (probably not). - // This may leak length of string. - return (s != null) ? Utf8.encode(s) : null; - } - private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { private final HashSet allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index a2699018b3..e856905cf0 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -177,17 +177,18 @@ public class CsrfWebFilter implements WebFilter { * @return */ private static boolean equalsConstantTime(String expected, String actual) { - byte[] expectedBytes = bytesUtf8(expected); - byte[] actualBytes = bytesUtf8(actual); + if (expected == actual) { + return true; + } + if (expected == null || actual == null) { + return false; + } + // Encode after ensure that the string is not null + byte[] expectedBytes = Utf8.encode(expected); + byte[] actualBytes = Utf8.encode(actual); return MessageDigest.isEqual(expectedBytes, actualBytes); } - private static byte[] bytesUtf8(String s) { - // need to check if Utf8.encode() runs in constant time (probably not). - // This may leak length of string. - return (s != null) ? Utf8.encode(s) : null; - } - private Mono generateToken(ServerWebExchange exchange) { return this.csrfTokenRepository.generateToken(exchange) .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token)); diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index e94eb34a87..3d9e1eaf8e 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.security.web.csrf; import java.io.IOException; +import java.lang.reflect.Method; import java.util.Arrays; import javax.servlet.FilterChain; @@ -96,6 +97,18 @@ public class CsrfFilterTests { this.response = new MockHttpServletResponse(); } + @Test + public void nullConstantTimeEquals() throws Exception { + Method method = CsrfFilter.class.getDeclaredMethod("equalsConstantTime", String.class, String.class); + method.setAccessible(true); + assertThat(method.invoke(CsrfFilter.class, null, null)).isEqualTo(true); + String expectedToken = "Hello—World"; + String actualToken = new String("Hello—World"); + assertThat(method.invoke(CsrfFilter.class, expectedToken, null)).isEqualTo(false); + assertThat(method.invoke(CsrfFilter.class, expectedToken, "hello-world")).isEqualTo(false); + assertThat(method.invoke(CsrfFilter.class, expectedToken, actualToken)).isEqualTo(true); + } + @Test public void constructorNullRepository() { assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null)); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index 5bae3f58c9..04c3d2715e 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.security.web.server.csrf; +import java.lang.reflect.Method; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -65,6 +67,18 @@ public class CsrfWebFilterTests { private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/")); + @Test + public void nullConstantTimeEquals() throws Exception { + Method method = CsrfWebFilter.class.getDeclaredMethod("equalsConstantTime", String.class, String.class); + method.setAccessible(true); + assertThat(method.invoke(CsrfWebFilter.class, null, null)).isEqualTo(true); + String expectedToken = "Hello—World"; + String actualToken = new String("Hello—World"); + assertThat(method.invoke(CsrfWebFilter.class, expectedToken, null)).isEqualTo(false); + assertThat(method.invoke(CsrfWebFilter.class, expectedToken, "hello-world")).isEqualTo(false); + assertThat(method.invoke(CsrfWebFilter.class, expectedToken, actualToken)).isEqualTo(true); + } + @Test public void filterWhenGetThenSessionNotCreatedAndChainContinues() { PublisherProbe chainResult = PublisherProbe.empty();