diff --git a/core/src/main/java/org/elasticsearch/common/bytes/BytesReference.java b/core/src/main/java/org/elasticsearch/common/bytes/BytesReference.java index f31ea2bbf82..92632ad7874 100644 --- a/core/src/main/java/org/elasticsearch/common/bytes/BytesReference.java +++ b/core/src/main/java/org/elasticsearch/common/bytes/BytesReference.java @@ -23,6 +23,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefIterator; import org.elasticsearch.common.io.stream.StreamInput; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -215,6 +216,7 @@ public abstract class BytesReference implements Accountable, Comparable Map readMap(Writeable.Reader keyReader, Writeable.Reader valueReader) throws IOException { - int size = readVInt(); + int size = readArraySize(); Map map = new HashMap<>(size); for (int i = 0; i < size; i++) { K key = keyReader.read(this); @@ -454,7 +453,7 @@ public abstract class StreamInput extends InputStream { */ public Map> readMapOfLists(final Writeable.Reader keyReader, final Writeable.Reader valueReader) throws IOException { - final int size = readVInt(); + final int size = readArraySize(); if (size == 0) { return Collections.emptyMap(); } @@ -531,7 +530,7 @@ public abstract class StreamInput extends InputStream { @SuppressWarnings("unchecked") private List readArrayList() throws IOException { - int size = readVInt(); + int size = readArraySize(); List list = new ArrayList(size); for (int i = 0; i < size; i++) { list.add(readGenericValue()); @@ -545,7 +544,7 @@ public abstract class StreamInput extends InputStream { } private Object[] readArray() throws IOException { - int size8 = readVInt(); + int size8 = readArraySize(); Object[] list8 = new Object[size8]; for (int i = 0; i < size8; i++) { list8[i] = readGenericValue(); @@ -554,7 +553,7 @@ public abstract class StreamInput extends InputStream { } private Map readLinkedHashMap() throws IOException { - int size9 = readVInt(); + int size9 = readArraySize(); Map map9 = new LinkedHashMap(size9); for (int i = 0; i < size9; i++) { map9.put(readString(), readGenericValue()); @@ -563,7 +562,7 @@ public abstract class StreamInput extends InputStream { } private Map readHashMap() throws IOException { - int size10 = readVInt(); + int size10 = readArraySize(); Map map10 = new HashMap(size10); for (int i = 0; i < size10; i++) { map10.put(readString(), readGenericValue()); @@ -600,7 +599,7 @@ public abstract class StreamInput extends InputStream { } public int[] readIntArray() throws IOException { - int length = readVInt(); + int length = readArraySize(); int[] values = new int[length]; for (int i = 0; i < length; i++) { values[i] = readInt(); @@ -609,7 +608,7 @@ public abstract class StreamInput extends InputStream { } public int[] readVIntArray() throws IOException { - int length = readVInt(); + int length = readArraySize(); int[] values = new int[length]; for (int i = 0; i < length; i++) { values[i] = readVInt(); @@ -618,7 +617,7 @@ public abstract class StreamInput extends InputStream { } public long[] readLongArray() throws IOException { - int length = readVInt(); + int length = readArraySize(); long[] values = new long[length]; for (int i = 0; i < length; i++) { values[i] = readLong(); @@ -627,7 +626,7 @@ public abstract class StreamInput extends InputStream { } public long[] readVLongArray() throws IOException { - int length = readVInt(); + int length = readArraySize(); long[] values = new long[length]; for (int i = 0; i < length; i++) { values[i] = readVLong(); @@ -636,7 +635,7 @@ public abstract class StreamInput extends InputStream { } public float[] readFloatArray() throws IOException { - int length = readVInt(); + int length = readArraySize(); float[] values = new float[length]; for (int i = 0; i < length; i++) { values[i] = readFloat(); @@ -645,7 +644,7 @@ public abstract class StreamInput extends InputStream { } public double[] readDoubleArray() throws IOException { - int length = readVInt(); + int length = readArraySize(); double[] values = new double[length]; for (int i = 0; i < length; i++) { values[i] = readDouble(); @@ -654,14 +653,14 @@ public abstract class StreamInput extends InputStream { } public byte[] readByteArray() throws IOException { - final int length = readVInt(); + final int length = readArraySize(); final byte[] bytes = new byte[length]; readBytes(bytes, 0, bytes.length); return bytes; } public T[] readArray(Writeable.Reader reader, IntFunction arraySupplier) throws IOException { - int length = readVInt(); + int length = readArraySize(); T[] values = arraySupplier.apply(length); for (int i = 0; i < length; i++) { values[i] = reader.read(this); @@ -833,7 +832,7 @@ public abstract class StreamInput extends InputStream { * @throws IOException if any step fails */ public List readStreamableList(Supplier constructor) throws IOException { - int count = readVInt(); + int count = readArraySize(); List builder = new ArrayList<>(count); for (int i=0; i List readList(Writeable.Reader reader) throws IOException { - int count = readVInt(); + int count = readArraySize(); List builder = new ArrayList<>(count); for (int i=0; i List readNamedWriteableList(Class categoryClass) throws IOException { - int count = readVInt(); + int count = readArraySize(); List builder = new ArrayList<>(count); for (int i=0; i ArrayUtil.MAX_ARRAY_LENGTH) { + throw new IllegalStateException("array length must be <= to " + ArrayUtil.MAX_ARRAY_LENGTH + " but was: " + arraySize); + } + if (arraySize < 0) { + throw new NegativeArraySizeException("array size must be positive but was: " + arraySize); + } + // lets do a sanity check that if we are reading an array size that is bigger that the remaining bytes we can safely + // throw an exception instead of allocating the array based on the size. A simple corrutpted byte can make a node go OOM + // if the size is large and for perf reasons we allocate arrays ahead of time + ensureCanReadBytes(arraySize); + return arraySize; + } + + /** + * This method throws an {@link EOFException} if the given number of bytes can not be read from the this stream. This method might + * be a no-op depending on the underlying implementation if the information of the remaining bytes is not present. + */ + protected abstract void ensureCanReadBytes(int length) throws EOFException; + } diff --git a/core/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java b/core/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java index 58aa60a23c8..ba6da4ba522 100644 --- a/core/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java +++ b/core/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java @@ -20,8 +20,10 @@ package org.elasticsearch.index.translog; import org.apache.lucene.store.BufferedChecksum; +import org.elasticsearch.common.io.stream.FilterStreamInput; import org.elasticsearch.common.io.stream.StreamInput; +import java.io.EOFException; import java.io.IOException; import java.util.zip.CRC32; import java.util.zip.Checksum; @@ -30,19 +32,18 @@ import java.util.zip.Checksum; * Similar to Lucene's BufferedChecksumIndexInput, however this wraps a * {@link StreamInput} so anything read will update the checksum */ -public final class BufferedChecksumStreamInput extends StreamInput { +public final class BufferedChecksumStreamInput extends FilterStreamInput { private static final int SKIP_BUFFER_SIZE = 1024; private byte[] skipBuffer; - private final StreamInput in; private final Checksum digest; public BufferedChecksumStreamInput(StreamInput in) { - this.in = in; + super(in); this.digest = new BufferedChecksum(new CRC32()); } public BufferedChecksumStreamInput(StreamInput in, BufferedChecksumStreamInput reuse) { - this.in = in; + super(in); if (reuse == null ) { this.digest = new BufferedChecksum(new CRC32()); } else { @@ -58,20 +59,20 @@ public final class BufferedChecksumStreamInput extends StreamInput { @Override public byte readByte() throws IOException { - final byte b = in.readByte(); + final byte b = delegate.readByte(); digest.update(b); return b; } @Override public void readBytes(byte[] b, int offset, int len) throws IOException { - in.readBytes(b, offset, len); + delegate.readBytes(b, offset, len); digest.update(b, offset, len); } @Override public void reset() throws IOException { - in.reset(); + delegate.reset(); digest.reset(); } @@ -80,14 +81,9 @@ public final class BufferedChecksumStreamInput extends StreamInput { return readByte() & 0xFF; } - @Override - public void close() throws IOException { - in.close(); - } - @Override public boolean markSupported() { - return in.markSupported(); + return delegate.markSupported(); } @@ -109,17 +105,14 @@ public final class BufferedChecksumStreamInput extends StreamInput { return skipped; } - @Override - public int available() throws IOException { - return in.available(); - } @Override public synchronized void mark(int readlimit) { - in.mark(readlimit); + delegate.mark(readlimit); } public void resetDigest() { digest.reset(); } + } diff --git a/core/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java b/core/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java index e9958c1c516..866a02476e7 100644 --- a/core/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java +++ b/core/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.test.ESTestCase; import org.joda.time.DateTimeZone; +import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -696,4 +697,69 @@ public class BytesStreamsTests extends ESTestCase { } } } + + public void testReadTooLargeArraySize() throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput(0)) { + output.writeVInt(10); + for (int i = 0; i < 10; i ++) { + output.writeInt(i); + } + + output.writeVInt(Integer.MAX_VALUE); + for (int i = 0; i < 10; i ++) { + output.writeInt(i); + } + try (StreamInput streamInput = output.bytes().streamInput()) { + int[] ints = streamInput.readIntArray(); + for (int i = 0; i < 10; i ++) { + assertEquals(i, ints[i]); + } + expectThrows(IllegalStateException.class, () -> streamInput.readIntArray()); + } + } + } + + public void testReadCorruptedArraySize() throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput(0)) { + output.writeVInt(10); + for (int i = 0; i < 10; i ++) { + output.writeInt(i); + } + + output.writeVInt(100); + for (int i = 0; i < 10; i ++) { + output.writeInt(i); + } + try (StreamInput streamInput = output.bytes().streamInput()) { + int[] ints = streamInput.readIntArray(); + for (int i = 0; i < 10; i ++) { + assertEquals(i, ints[i]); + } + EOFException eofException = expectThrows(EOFException.class, () -> streamInput.readIntArray()); + assertEquals("tried to read: 100 bytes but only 40 remaining", eofException.getMessage()); + } + } + } + + public void testReadNegativeArraySize() throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput(0)) { + output.writeVInt(10); + for (int i = 0; i < 10; i ++) { + output.writeInt(i); + } + + output.writeVInt(Integer.MIN_VALUE); + for (int i = 0; i < 10; i ++) { + output.writeInt(i); + } + try (StreamInput streamInput = output.bytes().streamInput()) { + int[] ints = streamInput.readIntArray(); + for (int i = 0; i < 10; i ++) { + assertEquals(i, ints[i]); + } + NegativeArraySizeException exception = expectThrows(NegativeArraySizeException.class, () -> streamInput.readIntArray()); + assertEquals("array size must be positive but was: -2147483648", exception.getMessage()); + } + } + } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java index 2219ce31ff6..45aa029b46f 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/ByteBufStreamInput.java @@ -24,6 +24,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; +import java.io.EOFException; import java.io.IOException; /** @@ -67,6 +68,14 @@ class ByteBufStreamInput extends StreamInput { return endIndex - buffer.readerIndex(); } + @Override + protected void ensureCanReadBytes(int length) throws EOFException { + int bytesAvailable = endIndex - buffer.readerIndex(); + if (bytesAvailable < length) { + throw new EOFException("tried to read: " + length + " bytes but only " + bytesAvailable + " remaining"); + } + } + @Override public void mark(int readlimit) { buffer.markReaderIndex();