Avoid corruption when deserializing booleans

Today we write 0x00 or 0x01 for false or true when serializing a boolean
(and 0x02 for null when serializing an optional boolean) but we
deserialize any non-zero byte to true (except when deserializing an
optional boolean in which case we deserialize 0x02 to null, 0x01 to
true, and any other non-zero byte to false). This too easily allows
corruption into the stream. Instead, we should mark the stream as
corrupted and stop deserializing. This catches when we try to
deserialize something as a boolean that is not a boolean.

Relates #22152
This commit is contained in:
Jason Tedor 2016-12-13 20:10:05 -05:00 committed by GitHub
parent 510ad7b9c7
commit 0d195f1afa
3 changed files with 89 additions and 8 deletions

View File

@ -59,6 +59,7 @@ import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.function.IntFunction; import java.util.function.IntFunction;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -386,19 +387,28 @@ public abstract class StreamInput extends InputStream {
* Reads a boolean. * Reads a boolean.
*/ */
public final boolean readBoolean() throws IOException { public final boolean readBoolean() throws IOException {
return readByte() != 0; return readBoolean(readByte());
}
private boolean readBoolean(final byte value) {
if (value == 0) {
return false;
} else if (value == 1) {
return true;
} else {
final String message = String.format(Locale.ROOT, "unexpected byte [0x%02x]", value);
throw new IllegalStateException(message);
}
} }
@Nullable @Nullable
public final Boolean readOptionalBoolean() throws IOException { public final Boolean readOptionalBoolean() throws IOException {
byte val = readByte(); final byte value = readByte();
if (val == 2) { if (value == 2) {
return null; return null;
} else {
return readBoolean(value);
} }
if (val == 1) {
return true;
}
return false;
} }
/** /**

View File

@ -368,7 +368,7 @@ public abstract class StreamOutput extends OutputStream {
if (b == null) { if (b == null) {
writeByte(TWO); writeByte(TWO);
} else { } else {
writeByte(b ? ONE : ZERO); writeBoolean(b);
} }
} }

View File

@ -19,6 +19,7 @@
package org.elasticsearch.common.io.stream; package org.elasticsearch.common.io.stream;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
@ -31,11 +32,80 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
public class StreamTests extends ESTestCase { public class StreamTests extends ESTestCase {
public void testBooleanSerialization() throws IOException {
final BytesStreamOutput output = new BytesStreamOutput();
output.writeBoolean(false);
output.writeBoolean(true);
final BytesReference bytesReference = output.bytes();
final BytesRef bytesRef = bytesReference.toBytesRef();
assertThat(bytesRef.length, equalTo(2));
final byte[] bytes = bytesRef.bytes;
assertThat(bytes[0], equalTo((byte) 0));
assertThat(bytes[1], equalTo((byte) 1));
final StreamInput input = bytesReference.streamInput();
assertFalse(input.readBoolean());
assertTrue(input.readBoolean());
final Set<Byte> set = IntStream.range(Byte.MIN_VALUE, Byte.MAX_VALUE).mapToObj(v -> (byte) v).collect(Collectors.toSet());
set.remove((byte) 0);
set.remove((byte) 1);
final byte[] corruptBytes = new byte[] { randomFrom(set) };
final BytesReference corrupt = new BytesArray(corruptBytes);
final IllegalStateException e = expectThrows(IllegalStateException.class, () -> corrupt.streamInput().readBoolean());
final String message = String.format(Locale.ROOT, "unexpected byte [0x%02x]", corruptBytes[0]);
assertThat(e, hasToString(containsString(message)));
}
public void testOptionalBooleanSerialization() throws IOException {
final BytesStreamOutput output = new BytesStreamOutput();
output.writeOptionalBoolean(false);
output.writeOptionalBoolean(true);
output.writeOptionalBoolean(null);
final BytesReference bytesReference = output.bytes();
final BytesRef bytesRef = bytesReference.toBytesRef();
assertThat(bytesRef.length, equalTo(3));
final byte[] bytes = bytesRef.bytes;
assertThat(bytes[0], equalTo((byte) 0));
assertThat(bytes[1], equalTo((byte) 1));
assertThat(bytes[2], equalTo((byte) 2));
final StreamInput input = bytesReference.streamInput();
final Boolean maybeFalse = input.readOptionalBoolean();
assertNotNull(maybeFalse);
assertFalse(maybeFalse);
final Boolean maybeTrue = input.readOptionalBoolean();
assertNotNull(maybeTrue);
assertTrue(maybeTrue);
assertNull(input.readOptionalBoolean());
final Set<Byte> set = IntStream.range(Byte.MIN_VALUE, Byte.MAX_VALUE).mapToObj(v -> (byte) v).collect(Collectors.toSet());
set.remove((byte) 0);
set.remove((byte) 1);
set.remove((byte) 2);
final byte[] corruptBytes = new byte[] { randomFrom(set) };
final BytesReference corrupt = new BytesArray(corruptBytes);
final IllegalStateException e = expectThrows(IllegalStateException.class, () -> corrupt.streamInput().readOptionalBoolean());
final String message = String.format(Locale.ROOT, "unexpected byte [0x%02x]", corruptBytes[0]);
assertThat(e, hasToString(containsString(message)));
}
public void testRandomVLongSerialization() throws IOException { public void testRandomVLongSerialization() throws IOException {
for (int i = 0; i < 1024; i++) { for (int i = 0; i < 1024; i++) {
long write = randomLong(); long write = randomLong();
@ -179,4 +249,5 @@ public class StreamTests extends ESTestCase {
out.writeString(string); out.writeString(string);
} }
} }
} }