Avoid StackOverflowError if write circular reference exception (#54147)

We should never write a circular reference exception as we will fail a 
node with StackOverflowError. However, we have one in #53589. 
I tried but failed to find its location. With this commit, we will avoid 
StackOverflowError in production and detect circular exceptions in
tests.

Closes #53589
This commit is contained in:
Nhat Nguyen 2020-03-25 09:21:51 -04:00
parent 05c5529b2d
commit 4ecc7dcca5
3 changed files with 45 additions and 5 deletions

View File

@ -276,7 +276,7 @@ public class ElasticsearchException extends RuntimeException implements ToXConte
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(this.getMessage()); out.writeOptionalString(this.getMessage());
out.writeException(this.getCause()); out.writeException(this.getCause());
writeStackTraces(this, out); writeStackTraces(this, out, StreamOutput::writeException);
out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString); out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString);
out.writeMapOfLists(metadata, StreamOutput::writeString, StreamOutput::writeString); out.writeMapOfLists(metadata, StreamOutput::writeString, StreamOutput::writeString);
} }
@ -715,7 +715,8 @@ public class ElasticsearchException extends RuntimeException implements ToXConte
/** /**
* Serializes the given exceptions stacktrace elements as well as it's suppressed exceptions to the given output stream. * Serializes the given exceptions stacktrace elements as well as it's suppressed exceptions to the given output stream.
*/ */
public static <T extends Throwable> T writeStackTraces(T throwable, StreamOutput out) throws IOException { public static <T extends Throwable> T writeStackTraces(T throwable, StreamOutput out,
Writer<Throwable> exceptionWriter) throws IOException {
StackTraceElement[] stackTrace = throwable.getStackTrace(); StackTraceElement[] stackTrace = throwable.getStackTrace();
out.writeVInt(stackTrace.length); out.writeVInt(stackTrace.length);
for (StackTraceElement element : stackTrace) { for (StackTraceElement element : stackTrace) {
@ -727,7 +728,7 @@ public class ElasticsearchException extends RuntimeException implements ToXConte
Throwable[] suppressed = throwable.getSuppressed(); Throwable[] suppressed = throwable.getSuppressed();
out.writeVInt(suppressed.length); out.writeVInt(suppressed.length);
for (Throwable t : suppressed) { for (Throwable t : suppressed) {
out.writeException(t); exceptionWriter.write(out, t);
} }
return throwable; return throwable;
} }

View File

@ -90,6 +90,7 @@ import java.util.function.IntFunction;
public abstract class StreamOutput extends OutputStream { public abstract class StreamOutput extends OutputStream {
private static final Map<TimeUnit, Byte> TIME_UNIT_BYTE_MAP; private static final Map<TimeUnit, Byte> TIME_UNIT_BYTE_MAP;
private static final int MAX_NESTED_EXCEPTION_LEVEL = 100;
static { static {
final Map<TimeUnit, Byte> timeUnitByteMap = new EnumMap<>(TimeUnit.class); final Map<TimeUnit, Byte> timeUnitByteMap = new EnumMap<>(TimeUnit.class);
@ -910,8 +911,15 @@ public abstract class StreamOutput extends OutputStream {
} }
public void writeException(Throwable throwable) throws IOException { public void writeException(Throwable throwable) throws IOException {
writeException(throwable, throwable, 0);
}
private void writeException(Throwable rootException, Throwable throwable, int nestedLevel) throws IOException {
if (throwable == null) { if (throwable == null) {
writeBoolean(false); writeBoolean(false);
} else if (nestedLevel > MAX_NESTED_EXCEPTION_LEVEL) {
assert failOnTooManyNestedExceptions(rootException);
writeException(new IllegalStateException("too many nested exceptions"));
} else { } else {
writeBoolean(true); writeBoolean(true);
boolean writeCause = true; boolean writeCause = true;
@ -1020,12 +1028,16 @@ public abstract class StreamOutput extends OutputStream {
writeOptionalString(throwable.getMessage()); writeOptionalString(throwable.getMessage());
} }
if (writeCause) { if (writeCause) {
writeException(throwable.getCause()); writeException(rootException, throwable.getCause(), nestedLevel + 1);
} }
ElasticsearchException.writeStackTraces(throwable, this); ElasticsearchException.writeStackTraces(throwable, this, (o, t) -> o.writeException(rootException, t, nestedLevel + 1));
} }
} }
boolean failOnTooManyNestedExceptions(Throwable throwable) {
throw new AssertionError("too many nested exceptions", throwable);
}
/** /**
* Writes a {@link NamedWriteable} to the current stream, by first writing its name and then the object itself * Writes a {@link NamedWriteable} to the current stream, by first writing its name and then the object itself
*/ */

View File

@ -19,6 +19,7 @@
package org.elasticsearch.common.io.stream; package org.elasticsearch.common.io.stream;
import org.apache.lucene.store.AlreadyClosedException;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Constants; import org.apache.lucene.util.Constants;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
@ -49,12 +50,14 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
/** /**
* Tests for {@link BytesStreamOutput} paging behaviour. * Tests for {@link BytesStreamOutput} paging behaviour.
@ -850,4 +853,28 @@ public class BytesStreamsTests extends ESTestCase {
assertEqualityAfterSerialize(timeValue, 1 + out.bytes().length()); assertEqualityAfterSerialize(timeValue, 1 + out.bytes().length());
} }
public void testWriteCircularReferenceException() throws IOException {
IOException rootEx = new IOException("disk broken");
AlreadyClosedException ace = new AlreadyClosedException("closed", rootEx);
rootEx.addSuppressed(ace); // circular reference
BytesStreamOutput testOut = new BytesStreamOutput();
AssertionError error = expectThrows(AssertionError.class, () -> testOut.writeException(rootEx));
assertThat(error.getMessage(), containsString("too many nested exceptions"));
assertThat(error.getCause(), equalTo(rootEx));
BytesStreamOutput prodOut = new BytesStreamOutput() {
@Override
boolean failOnTooManyNestedExceptions(Throwable throwable) {
assertThat(throwable, sameInstance(rootEx));
return true;
}
};
prodOut.writeException(rootEx);
StreamInput in = prodOut.bytes().streamInput();
Exception newEx = in.readException();
assertThat(newEx, instanceOf(IOException.class));
assertThat(newEx.getMessage(), equalTo("disk broken"));
assertArrayEquals(newEx.getStackTrace(), rootEx.getStackTrace());
}
} }