From 99a44a04f7a7bca7f1d487ac38b4897a6b15d72e Mon Sep 17 00:00:00 2001
From: Armin Braun <me@obrown.io>
Date: Thu, 20 Jun 2019 16:38:28 +0200
Subject: [PATCH] Fix Infinite Loops in ExceptionsHelper#unwrap (#42716)
 (#43421)

* Fix Infinite Loops in ExceptionsHelper#unwrap

* Keep track of all seen exceptions and break out on loops
* Closes #42340
---
 .../org/elasticsearch/ExceptionsHelper.java   | 70 ++++++++-----------
 .../elasticsearch/index/engine/Engine.java    |  2 +-
 .../elasticsearch/ExceptionsHelperTests.java  | 39 ++++++++---
 3 files changed, 60 insertions(+), 51 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/ExceptionsHelper.java b/server/src/main/java/org/elasticsearch/ExceptionsHelper.java
index 94c4a273159..2bc1d590106 100644
--- a/server/src/main/java/org/elasticsearch/ExceptionsHelper.java
+++ b/server/src/main/java/org/elasticsearch/ExceptionsHelper.java
@@ -38,12 +38,14 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.IdentityHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Queue;
 import java.util.Set;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
 public final class ExceptionsHelper {
@@ -185,22 +187,14 @@ public final class ExceptionsHelper {
      * @return Corruption indicating exception if one is found, otherwise {@code null}
      */
     public static IOException unwrapCorruption(Throwable t) {
-        if (t != null) {
-            do {
-                for (Class<?> clazz : CORRUPTION_EXCEPTIONS) {
-                    if (clazz.isInstance(t)) {
-                        return (IOException) t;
-                    }
+        return t == null ? null : ExceptionsHelper.<IOException>unwrapCausesAndSuppressed(t, cause -> {
+            for (Class<?> clazz : CORRUPTION_EXCEPTIONS) {
+                if (clazz.isInstance(cause)) {
+                    return true;
                 }
-                for (Throwable suppressed : t.getSuppressed()) {
-                    IOException corruptionException = unwrapCorruption(suppressed);
-                    if (corruptionException != null) {
-                        return corruptionException;
-                    }
-                }
-            } while ((t = t.getCause()) != null);
-        }
-        return null;
+            }
+            return false;
+        }).orElse(null);
     }
 
     /**
@@ -213,7 +207,11 @@ public final class ExceptionsHelper {
      */
     public static Throwable unwrap(Throwable t, Class<?>... clazzes) {
         if (t != null) {
+            final Set<Throwable> seen = Collections.newSetFromMap(new IdentityHashMap<>());
             do {
+                if (seen.add(t) == false) {
+                    return null;
+                }
                 for (Class<?> clazz : clazzes) {
                     if (clazz.isInstance(t)) {
                         return t;
@@ -246,33 +244,22 @@ public final class ExceptionsHelper {
         return true;
     }
 
-    static final int MAX_ITERATIONS = 1024;
-
-    /**
-     * Unwrap the specified throwable looking for any suppressed errors or errors as a root cause of the specified throwable.
-     *
-     * @param cause the root throwable
-     * @return an optional error if one is found suppressed or a root cause in the tree rooted at the specified throwable
-     */
-    public static Optional<Error> maybeError(final Throwable cause, final Logger logger) {
-        // early terminate if the cause is already an error
-        if (cause instanceof Error) {
-            return Optional.of((Error) cause);
+    @SuppressWarnings("unchecked")
+    private static <T extends Throwable> Optional<T> unwrapCausesAndSuppressed(Throwable cause, Predicate<Throwable> predicate) {
+        if (predicate.test(cause)) {
+            return Optional.of((T) cause);
         }
 
         final Queue<Throwable> queue = new LinkedList<>();
         queue.add(cause);
-        int iterations = 0;
+        final Set<Throwable> seen = Collections.newSetFromMap(new IdentityHashMap<>());
         while (queue.isEmpty() == false) {
-            iterations++;
-            // this is a guard against deeply nested or circular chains of exceptions
-            if (iterations > MAX_ITERATIONS) {
-                logger.warn("giving up looking for fatal errors", cause);
-                break;
-            }
             final Throwable current = queue.remove();
-            if (current instanceof Error) {
-                return Optional.of((Error) current);
+            if (seen.add(current) == false) {
+                continue;
+            }
+            if (predicate.test(current)) {
+                return Optional.of((T) current);
             }
             Collections.addAll(queue, current.getSuppressed());
             if (current.getCause() != null) {
@@ -283,21 +270,24 @@ public final class ExceptionsHelper {
     }
 
     /**
-     * See {@link #maybeError(Throwable, Logger)}. Uses the class-local logger.
+     * Unwrap the specified throwable looking for any suppressed errors or errors as a root cause of the specified throwable.
+     *
+     * @param cause the root throwable
+     * @return an optional error if one is found suppressed or a root cause in the tree rooted at the specified throwable
      */
     public static Optional<Error> maybeError(final Throwable cause) {
-        return maybeError(cause, logger);
+        return unwrapCausesAndSuppressed(cause, t -> t instanceof Error);
     }
 
     /**
      * If the specified cause is an unrecoverable error, this method will rethrow the cause on a separate thread so that it can not be
      * caught and bubbles up to the uncaught exception handler. Note that the cause tree is examined for any {@link Error}. See
-     * {@link #maybeError(Throwable, Logger)} for the semantics.
+     * {@link #maybeError(Throwable)} for the semantics.
      *
      * @param throwable the throwable to possibly throw on another thread
      */
     public static void maybeDieOnAnotherThread(final Throwable throwable) {
-        ExceptionsHelper.maybeError(throwable, logger).ifPresent(error -> {
+        ExceptionsHelper.maybeError(throwable).ifPresent(error -> {
             /*
              * Here be dragons. We want to rethrow this so that it bubbles up to the uncaught exception handler. Yet, sometimes the stack
              * contains statements that catch any throwable (e.g., Netty, and the JDK futures framework). This means that a rethrow here
diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java
index ebcdb845d00..e21b816aefd 100644
--- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java
+++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java
@@ -1142,7 +1142,7 @@ public abstract class Engine implements Closeable {
      */
     @SuppressWarnings("finally")
     private void maybeDie(final String maybeMessage, final Throwable maybeFatal) {
-        ExceptionsHelper.maybeError(maybeFatal, logger).ifPresent(error -> {
+        ExceptionsHelper.maybeError(maybeFatal).ifPresent(error -> {
             try {
                 logger.error(maybeMessage, error);
             } finally {
diff --git a/server/src/test/java/org/elasticsearch/ExceptionsHelperTests.java b/server/src/test/java/org/elasticsearch/ExceptionsHelperTests.java
index 2de2f259e6f..3b5d1ad43da 100644
--- a/server/src/test/java/org/elasticsearch/ExceptionsHelperTests.java
+++ b/server/src/test/java/org/elasticsearch/ExceptionsHelperTests.java
@@ -35,9 +35,9 @@ import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.transport.RemoteClusterAware;
 
+import java.io.IOException;
 import java.util.Optional;
 
-import static org.elasticsearch.ExceptionsHelper.MAX_ITERATIONS;
 import static org.elasticsearch.ExceptionsHelper.maybeError;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.instanceOf;
@@ -81,20 +81,14 @@ public class ExceptionsHelperTests extends ESTestCase {
         if (fatal) {
             assertError(cause, error);
         } else {
-            assertFalse(maybeError(cause, logger).isPresent());
+            assertFalse(maybeError(cause).isPresent());
         }
 
-        assertFalse(maybeError(new Exception(new DecoderException()), logger).isPresent());
-
-        Throwable chain = outOfMemoryError;
-        for (int i = 0; i < MAX_ITERATIONS; i++) {
-            chain = new Exception(chain);
-        }
-        assertFalse(maybeError(chain, logger).isPresent());
+        assertFalse(maybeError(new Exception(new DecoderException())).isPresent());
     }
 
     private void assertError(final Throwable cause, final Error error) {
-        final Optional<Error> maybeError = maybeError(cause, logger);
+        final Optional<Error> maybeError = maybeError(cause);
         assertTrue(maybeError.isPresent());
         assertThat(maybeError.get(), equalTo(error));
     }
@@ -211,4 +205,29 @@ public class ExceptionsHelperTests extends ESTestCase {
         withSuppressedException.addSuppressed(new RuntimeException());
         assertThat(ExceptionsHelper.unwrapCorruption(withSuppressedException), nullValue());
     }
+
+    public void testSuppressedCycle() {
+        RuntimeException e1 = new RuntimeException();
+        RuntimeException e2 = new RuntimeException();
+        e1.addSuppressed(e2);
+        e2.addSuppressed(e1);
+        ExceptionsHelper.unwrapCorruption(e1);
+
+        final CorruptIndexException corruptIndexException = new CorruptIndexException("corrupt", "resource");
+        RuntimeException e3 = new RuntimeException(corruptIndexException);
+        e3.addSuppressed(e1);
+        assertThat(ExceptionsHelper.unwrapCorruption(e3), equalTo(corruptIndexException));
+
+        RuntimeException e4 = new RuntimeException(e1);
+        e4.addSuppressed(corruptIndexException);
+        assertThat(ExceptionsHelper.unwrapCorruption(e4), equalTo(corruptIndexException));
+    }
+
+    public void testCauseCycle() {
+        RuntimeException e1 = new RuntimeException();
+        RuntimeException e2 = new RuntimeException(e1);
+        e1.initCause(e2);
+        ExceptionsHelper.unwrap(e1, IOException.class);
+        ExceptionsHelper.unwrapCorruption(e1);
+    }
 }