Add null check in CsrfFilter and CsrfWebFilter

Solve the problem that CsrfFilter and CsrfWebFilter
throws NPE exception when comparing two byte array
is equal in low JDK version.

When JDK version is lower than 1.8.0_45, method
java.security.MessageDigest#isEqual does not verify
whether the two arrays are null. And the above two
class call this method without null judgment.

ZiQiang Zhao<1694392889@qq.com>

Closes gh-9561
This commit is contained in:
佚名 2021-04-06 19:19:39 +08:00 committed by Josh Cummings
parent 2009b5faf0
commit 22d7043d01
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 49 additions and 20 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -174,15 +174,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
* @return * @return
*/ */
private static boolean equalsConstantTime(String expected, String actual) { private static boolean equalsConstantTime(String expected, String actual) {
byte[] expectedBytes = bytesUtf8(expected); if (expected == actual) {
byte[] actualBytes = bytesUtf8(actual); return true;
return MessageDigest.isEqual(expectedBytes, actualBytes);
} }
if (expected == null || actual == null) {
private static byte[] bytesUtf8(String s) { return false;
// need to check if Utf8.encode() runs in constant time (probably not). }
// This may leak length of string. // Encode after ensure that the string is not null
return (s != null) ? Utf8.encode(s) : null; byte[] expectedBytes = Utf8.encode(expected);
byte[] actualBytes = Utf8.encode(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
} }
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -177,15 +177,16 @@ public class CsrfWebFilter implements WebFilter {
* @return * @return
*/ */
private static boolean equalsConstantTime(String expected, String actual) { private static boolean equalsConstantTime(String expected, String actual) {
byte[] expectedBytes = bytesUtf8(expected); if (expected == actual) {
byte[] actualBytes = bytesUtf8(actual); return true;
return MessageDigest.isEqual(expectedBytes, actualBytes);
} }
if (expected == null || actual == null) {
private static byte[] bytesUtf8(String s) { return false;
// need to check if Utf8.encode() runs in constant time (probably not). }
// This may leak length of string. // Encode after ensure that the string is not null
return (s != null) ? Utf8.encode(s) : null; byte[] expectedBytes = Utf8.encode(expected);
byte[] actualBytes = Utf8.encode(actual);
return MessageDigest.isEqual(expectedBytes, actualBytes);
} }
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) { private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.security.web.csrf; package org.springframework.security.web.csrf;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Arrays; import java.util.Arrays;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
@ -89,6 +90,18 @@ public class CsrfFilterTests {
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
} }
@Test
public void nullConstantTimeEquals() throws Exception {
Method method = CsrfFilter.class.getDeclaredMethod("equalsConstantTime", String.class, String.class);
method.setAccessible(true);
assertThat(method.invoke(CsrfFilter.class, null, null)).isEqualTo(true);
String expectedToken = "Hello—World";
String actualToken = new String("Hello—World");
assertThat(method.invoke(CsrfFilter.class, expectedToken, null)).isEqualTo(false);
assertThat(method.invoke(CsrfFilter.class, expectedToken, "hello-world")).isEqualTo(false);
assertThat(method.invoke(CsrfFilter.class, expectedToken, actualToken)).isEqualTo(true);
}
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void constructorNullRepository() { public void constructorNullRepository() {
new CsrfFilter(null); new CsrfFilter(null);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package org.springframework.security.web.server.csrf; package org.springframework.security.web.server.csrf;
import java.lang.reflect.Method;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
@ -67,6 +69,18 @@ public class CsrfWebFilterTests {
private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/")); private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/"));
@Test
public void nullConstantTimeEquals() throws Exception {
Method method = CsrfWebFilter.class.getDeclaredMethod("equalsConstantTime", String.class, String.class);
method.setAccessible(true);
assertThat(method.invoke(CsrfWebFilter.class, null, null)).isEqualTo(true);
String expectedToken = "Hello—World";
String actualToken = new String("Hello—World");
assertThat(method.invoke(CsrfWebFilter.class, expectedToken, null)).isEqualTo(false);
assertThat(method.invoke(CsrfWebFilter.class, expectedToken, "hello-world")).isEqualTo(false);
assertThat(method.invoke(CsrfWebFilter.class, expectedToken, actualToken)).isEqualTo(true);
}
@Test @Test
public void filterWhenGetThenSessionNotCreatedAndChainContinues() { public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
PublisherProbe<Void> chainResult = PublisherProbe.empty(); PublisherProbe<Void> chainResult = PublisherProbe.empty();