Add a StreamInput#readArraySize method that ensures sane array sizes (#21697)

Today we read a vint from the stream to allocate the size of an array up-front
before we start reading the values. This can be dangerous if for instance we read
from a corrupted stream or if some manipulated bytes are send for instance from
an attacker or a fuzzer. In most of the cases we can apply some best effort and
validate the array size to be _sane_ by ensuring we can at read at least N bytes
where N is the expected size of the array.
This commit is contained in:
Simon Willnauer 2016-11-21 21:39:21 +01:00 committed by GitHub
parent 0ccf8a742d
commit cb5c25ab4f
9 changed files with 166 additions and 41 deletions

View File

@ -23,6 +23,7 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.common.io.stream.StreamInput;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@ -215,6 +216,7 @@ public abstract class BytesReference implements Accountable, Comparable<BytesRef
* that way.
*/
private static final class MarkSupportingStreamInputWrapper extends StreamInput {
// can't use FilterStreamInput it needs to reset the delegate
private final BytesReference reference;
private BytesReferenceStreamInput input;
private int mark = 0;
@ -254,6 +256,11 @@ public abstract class BytesReference implements Accountable, Comparable<BytesRef
return input.available();
}
@Override
protected void ensureCanReadBytes(int length) throws EOFException {
input.ensureCanReadBytes(length);
}
@Override
public void reset() throws IOException {
input = new BytesReferenceStreamInput(reference.iterator(), reference.length());

View File

@ -114,6 +114,14 @@ final class BytesReferenceStreamInput extends StreamInput {
return length - offset;
}
@Override
protected void ensureCanReadBytes(int bytesToRead) throws EOFException {
int bytesAvailable = length - offset;
if (bytesAvailable < bytesToRead) {
throw new EOFException("tried to read: " + bytesToRead + " bytes but only " + bytesAvailable + " remaining");
}
}
@Override
public long skip(long n) throws IOException {
final int skip = (int) Math.min(Integer.MAX_VALUE, n);

View File

@ -86,6 +86,13 @@ public class ByteBufferStreamInput extends StreamInput {
return buffer.remaining();
}
@Override
protected void ensureCanReadBytes(int length) throws EOFException {
if (buffer.remaining() < length) {
throw new EOFException("tried to read: " + length + " bytes but only " + buffer.remaining() + " remaining");
}
}
@Override
public void mark(int readlimit) {
buffer.mark();

View File

@ -21,6 +21,7 @@ package org.elasticsearch.common.io.stream;
import org.elasticsearch.Version;
import java.io.EOFException;
import java.io.IOException;
/**
@ -28,7 +29,7 @@ import java.io.IOException;
*/
public abstract class FilterStreamInput extends StreamInput {
private final StreamInput delegate;
protected final StreamInput delegate;
protected FilterStreamInput(StreamInput delegate) {
this.delegate = delegate;
@ -73,4 +74,9 @@ public abstract class FilterStreamInput extends StreamInput {
public void setVersion(Version version) {
delegate.setVersion(version);
}
@Override
protected void ensureCanReadBytes(int length) throws EOFException {
delegate.ensureCanReadBytes(length);
}
}

View File

@ -95,4 +95,9 @@ public class InputStreamStreamInput extends StreamInput {
public long skip(long n) throws IOException {
return is.skip(n);
}
@Override
protected void ensureCanReadBytes(int length) throws EOFException {
// TODO what can we do here?
}
}

View File

@ -28,7 +28,6 @@ import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
@ -113,7 +112,7 @@ public abstract class StreamInput extends InputStream {
* bytes of the stream.
*/
public BytesReference readBytesReference() throws IOException {
int length = readVInt();
int length = readArraySize();
return readBytesReference(length);
}
@ -145,7 +144,7 @@ public abstract class StreamInput extends InputStream {
}
public BytesRef readBytesRef() throws IOException {
int length = readVInt();
int length = readArraySize();
return readBytesRef(length);
}
@ -332,7 +331,7 @@ public abstract class StreamInput extends InputStream {
public String readString() throws IOException {
// TODO it would be nice to not call readByte() for every character but we don't know how much to read up-front
// we can make the loop much more complicated but that won't buy us much compared to the bounds checks in readByte()
final int charCount = readVInt();
final int charCount = readArraySize();
if (spare.chars.length < charCount) {
// we don't use ArrayUtils.grow since there is no need to copy the array
spare.chars = new char[ArrayUtil.oversize(charCount, Character.BYTES)];
@ -412,7 +411,7 @@ public abstract class StreamInput extends InputStream {
public abstract int available() throws IOException;
public String[] readStringArray() throws IOException {
int size = readVInt();
int size = readArraySize();
if (size == 0) {
return Strings.EMPTY_ARRAY;
}
@ -432,7 +431,7 @@ public abstract class StreamInput extends InputStream {
}
public <K, V> Map<K, V> readMap(Writeable.Reader<K> keyReader, Writeable.Reader<V> valueReader) throws IOException {
int size = readVInt();
int size = readArraySize();
Map<K, V> map = new HashMap<>(size);
for (int i = 0; i < size; i++) {
K key = keyReader.read(this);
@ -454,7 +453,7 @@ public abstract class StreamInput extends InputStream {
*/
public <K, V> Map<K, List<V>> readMapOfLists(final Writeable.Reader<K> keyReader, final Writeable.Reader<V> valueReader)
throws IOException {
final int size = readVInt();
final int size = readArraySize();
if (size == 0) {
return Collections.emptyMap();
}
@ -531,7 +530,7 @@ public abstract class StreamInput extends InputStream {
@SuppressWarnings("unchecked")
private List readArrayList() throws IOException {
int size = readVInt();
int size = readArraySize();
List list = new ArrayList(size);
for (int i = 0; i < size; i++) {
list.add(readGenericValue());
@ -545,7 +544,7 @@ public abstract class StreamInput extends InputStream {
}
private Object[] readArray() throws IOException {
int size8 = readVInt();
int size8 = readArraySize();
Object[] list8 = new Object[size8];
for (int i = 0; i < size8; i++) {
list8[i] = readGenericValue();
@ -554,7 +553,7 @@ public abstract class StreamInput extends InputStream {
}
private Map readLinkedHashMap() throws IOException {
int size9 = readVInt();
int size9 = readArraySize();
Map map9 = new LinkedHashMap(size9);
for (int i = 0; i < size9; i++) {
map9.put(readString(), readGenericValue());
@ -563,7 +562,7 @@ public abstract class StreamInput extends InputStream {
}
private Map readHashMap() throws IOException {
int size10 = readVInt();
int size10 = readArraySize();
Map map10 = new HashMap(size10);
for (int i = 0; i < size10; i++) {
map10.put(readString(), readGenericValue());
@ -600,7 +599,7 @@ public abstract class StreamInput extends InputStream {
}
public int[] readIntArray() throws IOException {
int length = readVInt();
int length = readArraySize();
int[] values = new int[length];
for (int i = 0; i < length; i++) {
values[i] = readInt();
@ -609,7 +608,7 @@ public abstract class StreamInput extends InputStream {
}
public int[] readVIntArray() throws IOException {
int length = readVInt();
int length = readArraySize();
int[] values = new int[length];
for (int i = 0; i < length; i++) {
values[i] = readVInt();
@ -618,7 +617,7 @@ public abstract class StreamInput extends InputStream {
}
public long[] readLongArray() throws IOException {
int length = readVInt();
int length = readArraySize();
long[] values = new long[length];
for (int i = 0; i < length; i++) {
values[i] = readLong();
@ -627,7 +626,7 @@ public abstract class StreamInput extends InputStream {
}
public long[] readVLongArray() throws IOException {
int length = readVInt();
int length = readArraySize();
long[] values = new long[length];
for (int i = 0; i < length; i++) {
values[i] = readVLong();
@ -636,7 +635,7 @@ public abstract class StreamInput extends InputStream {
}
public float[] readFloatArray() throws IOException {
int length = readVInt();
int length = readArraySize();
float[] values = new float[length];
for (int i = 0; i < length; i++) {
values[i] = readFloat();
@ -645,7 +644,7 @@ public abstract class StreamInput extends InputStream {
}
public double[] readDoubleArray() throws IOException {
int length = readVInt();
int length = readArraySize();
double[] values = new double[length];
for (int i = 0; i < length; i++) {
values[i] = readDouble();
@ -654,14 +653,14 @@ public abstract class StreamInput extends InputStream {
}
public byte[] readByteArray() throws IOException {
final int length = readVInt();
final int length = readArraySize();
final byte[] bytes = new byte[length];
readBytes(bytes, 0, bytes.length);
return bytes;
}
public <T> T[] readArray(Writeable.Reader<T> reader, IntFunction<T[]> arraySupplier) throws IOException {
int length = readVInt();
int length = readArraySize();
T[] values = arraySupplier.apply(length);
for (int i = 0; i < length; i++) {
values[i] = reader.read(this);
@ -833,7 +832,7 @@ public abstract class StreamInput extends InputStream {
* @throws IOException if any step fails
*/
public <T extends Streamable> List<T> readStreamableList(Supplier<T> constructor) throws IOException {
int count = readVInt();
int count = readArraySize();
List<T> builder = new ArrayList<>(count);
for (int i=0; i<count; i++) {
T instance = constructor.get();
@ -847,7 +846,7 @@ public abstract class StreamInput extends InputStream {
* Reads a list of objects
*/
public <T> List<T> readList(Writeable.Reader<T> reader) throws IOException {
int count = readVInt();
int count = readArraySize();
List<T> builder = new ArrayList<>(count);
for (int i=0; i<count; i++) {
builder.add(reader.read(this));
@ -859,7 +858,7 @@ public abstract class StreamInput extends InputStream {
* Reads a list of {@link NamedWriteable}s.
*/
public <T extends NamedWriteable> List<T> readNamedWriteableList(Class<T> categoryClass) throws IOException {
int count = readVInt();
int count = readArraySize();
List<T> builder = new ArrayList<>(count);
for (int i=0; i<count; i++) {
builder.add(readNamedWriteable(categoryClass));
@ -875,4 +874,29 @@ public abstract class StreamInput extends InputStream {
return new InputStreamStreamInput(new ByteArrayInputStream(bytes, offset, length));
}
/**
* Reads a vint via {@link #readVInt()} and applies basic checks to ensure the read array size is sane.
* This method uses {@link #ensureCanReadBytes(int)} to ensure this stream has enough bytes to read for the read array size.
*/
private int readArraySize() throws IOException {
final int arraySize = readVInt();
if (arraySize > ArrayUtil.MAX_ARRAY_LENGTH) {
throw new IllegalStateException("array length must be <= to " + ArrayUtil.MAX_ARRAY_LENGTH + " but was: " + arraySize);
}
if (arraySize < 0) {
throw new NegativeArraySizeException("array size must be positive but was: " + arraySize);
}
// lets do a sanity check that if we are reading an array size that is bigger that the remaining bytes we can safely
// throw an exception instead of allocating the array based on the size. A simple corrutpted byte can make a node go OOM
// if the size is large and for perf reasons we allocate arrays ahead of time
ensureCanReadBytes(arraySize);
return arraySize;
}
/**
* This method throws an {@link EOFException} if the given number of bytes can not be read from the this stream. This method might
* be a no-op depending on the underlying implementation if the information of the remaining bytes is not present.
*/
protected abstract void ensureCanReadBytes(int length) throws EOFException;
}

View File

@ -20,8 +20,10 @@
package org.elasticsearch.index.translog;
import org.apache.lucene.store.BufferedChecksum;
import org.elasticsearch.common.io.stream.FilterStreamInput;
import org.elasticsearch.common.io.stream.StreamInput;
import java.io.EOFException;
import java.io.IOException;
import java.util.zip.CRC32;
import java.util.zip.Checksum;
@ -30,19 +32,18 @@ import java.util.zip.Checksum;
* Similar to Lucene's BufferedChecksumIndexInput, however this wraps a
* {@link StreamInput} so anything read will update the checksum
*/
public final class BufferedChecksumStreamInput extends StreamInput {
public final class BufferedChecksumStreamInput extends FilterStreamInput {
private static final int SKIP_BUFFER_SIZE = 1024;
private byte[] skipBuffer;
private final StreamInput in;
private final Checksum digest;
public BufferedChecksumStreamInput(StreamInput in) {
this.in = in;
super(in);
this.digest = new BufferedChecksum(new CRC32());
}
public BufferedChecksumStreamInput(StreamInput in, BufferedChecksumStreamInput reuse) {
this.in = in;
super(in);
if (reuse == null ) {
this.digest = new BufferedChecksum(new CRC32());
} else {
@ -58,20 +59,20 @@ public final class BufferedChecksumStreamInput extends StreamInput {
@Override
public byte readByte() throws IOException {
final byte b = in.readByte();
final byte b = delegate.readByte();
digest.update(b);
return b;
}
@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
in.readBytes(b, offset, len);
delegate.readBytes(b, offset, len);
digest.update(b, offset, len);
}
@Override
public void reset() throws IOException {
in.reset();
delegate.reset();
digest.reset();
}
@ -80,14 +81,9 @@ public final class BufferedChecksumStreamInput extends StreamInput {
return readByte() & 0xFF;
}
@Override
public void close() throws IOException {
in.close();
}
@Override
public boolean markSupported() {
return in.markSupported();
return delegate.markSupported();
}
@ -109,17 +105,14 @@ public final class BufferedChecksumStreamInput extends StreamInput {
return skipped;
}
@Override
public int available() throws IOException {
return in.available();
}
@Override
public synchronized void mark(int readlimit) {
in.mark(readlimit);
delegate.mark(readlimit);
}
public void resetDigest() {
digest.reset();
}
}

View File

@ -30,6 +30,7 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.test.ESTestCase;
import org.joda.time.DateTimeZone;
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
@ -696,4 +697,69 @@ public class BytesStreamsTests extends ESTestCase {
}
}
}
public void testReadTooLargeArraySize() throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput(0)) {
output.writeVInt(10);
for (int i = 0; i < 10; i ++) {
output.writeInt(i);
}
output.writeVInt(Integer.MAX_VALUE);
for (int i = 0; i < 10; i ++) {
output.writeInt(i);
}
try (StreamInput streamInput = output.bytes().streamInput()) {
int[] ints = streamInput.readIntArray();
for (int i = 0; i < 10; i ++) {
assertEquals(i, ints[i]);
}
expectThrows(IllegalStateException.class, () -> streamInput.readIntArray());
}
}
}
public void testReadCorruptedArraySize() throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput(0)) {
output.writeVInt(10);
for (int i = 0; i < 10; i ++) {
output.writeInt(i);
}
output.writeVInt(100);
for (int i = 0; i < 10; i ++) {
output.writeInt(i);
}
try (StreamInput streamInput = output.bytes().streamInput()) {
int[] ints = streamInput.readIntArray();
for (int i = 0; i < 10; i ++) {
assertEquals(i, ints[i]);
}
EOFException eofException = expectThrows(EOFException.class, () -> streamInput.readIntArray());
assertEquals("tried to read: 100 bytes but only 40 remaining", eofException.getMessage());
}
}
}
public void testReadNegativeArraySize() throws IOException {
try (BytesStreamOutput output = new BytesStreamOutput(0)) {
output.writeVInt(10);
for (int i = 0; i < 10; i ++) {
output.writeInt(i);
}
output.writeVInt(Integer.MIN_VALUE);
for (int i = 0; i < 10; i ++) {
output.writeInt(i);
}
try (StreamInput streamInput = output.bytes().streamInput()) {
int[] ints = streamInput.readIntArray();
for (int i = 0; i < 10; i ++) {
assertEquals(i, ints[i]);
}
NegativeArraySizeException exception = expectThrows(NegativeArraySizeException.class, () -> streamInput.readIntArray());
assertEquals("array size must be positive but was: -2147483648", exception.getMessage());
}
}
}
}

View File

@ -24,6 +24,7 @@ import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import java.io.EOFException;
import java.io.IOException;
/**
@ -67,6 +68,14 @@ class ByteBufStreamInput extends StreamInput {
return endIndex - buffer.readerIndex();
}
@Override
protected void ensureCanReadBytes(int length) throws EOFException {
int bytesAvailable = endIndex - buffer.readerIndex();
if (bytesAvailable < length) {
throw new EOFException("tried to read: " + length + " bytes but only " + bytesAvailable + " remaining");
}
}
@Override
public void mark(int readlimit) {
buffer.markReaderIndex();