Avoid corruption when deserializing booleans
Today we write 0x00 or 0x01 for false or true when serializing a boolean (and 0x02 for null when serializing an optional boolean) but we deserialize any non-zero byte to true (except when deserializing an optional boolean in which case we deserialize 0x02 to null, 0x01 to true, and any other non-zero byte to false). This too easily allows corruption into the stream. Instead, we should mark the stream as corrupted and stop deserializing. This catches when we try to deserialize something as a boolean that is not a boolean. Relates #22152
This commit is contained in:
parent
510ad7b9c7
commit
0d195f1afa
|
@ -59,6 +59,7 @@ import java.util.Date;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.function.IntFunction;
|
import java.util.function.IntFunction;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
@ -386,19 +387,28 @@ public abstract class StreamInput extends InputStream {
|
||||||
* Reads a boolean.
|
* Reads a boolean.
|
||||||
*/
|
*/
|
||||||
public final boolean readBoolean() throws IOException {
|
public final boolean readBoolean() throws IOException {
|
||||||
return readByte() != 0;
|
return readBoolean(readByte());
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean readBoolean(final byte value) {
|
||||||
|
if (value == 0) {
|
||||||
|
return false;
|
||||||
|
} else if (value == 1) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
final String message = String.format(Locale.ROOT, "unexpected byte [0x%02x]", value);
|
||||||
|
throw new IllegalStateException(message);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
public final Boolean readOptionalBoolean() throws IOException {
|
public final Boolean readOptionalBoolean() throws IOException {
|
||||||
byte val = readByte();
|
final byte value = readByte();
|
||||||
if (val == 2) {
|
if (value == 2) {
|
||||||
return null;
|
return null;
|
||||||
|
} else {
|
||||||
|
return readBoolean(value);
|
||||||
}
|
}
|
||||||
if (val == 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -368,7 +368,7 @@ public abstract class StreamOutput extends OutputStream {
|
||||||
if (b == null) {
|
if (b == null) {
|
||||||
writeByte(TWO);
|
writeByte(TWO);
|
||||||
} else {
|
} else {
|
||||||
writeByte(b ? ONE : ZERO);
|
writeBoolean(b);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.common.io.stream;
|
package org.elasticsearch.common.io.stream;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.elasticsearch.common.bytes.BytesArray;
|
import org.elasticsearch.common.bytes.BytesArray;
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
|
@ -31,11 +32,80 @@ import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.hasToString;
|
||||||
|
|
||||||
public class StreamTests extends ESTestCase {
|
public class StreamTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testBooleanSerialization() throws IOException {
|
||||||
|
final BytesStreamOutput output = new BytesStreamOutput();
|
||||||
|
output.writeBoolean(false);
|
||||||
|
output.writeBoolean(true);
|
||||||
|
|
||||||
|
final BytesReference bytesReference = output.bytes();
|
||||||
|
|
||||||
|
final BytesRef bytesRef = bytesReference.toBytesRef();
|
||||||
|
assertThat(bytesRef.length, equalTo(2));
|
||||||
|
final byte[] bytes = bytesRef.bytes;
|
||||||
|
assertThat(bytes[0], equalTo((byte) 0));
|
||||||
|
assertThat(bytes[1], equalTo((byte) 1));
|
||||||
|
|
||||||
|
final StreamInput input = bytesReference.streamInput();
|
||||||
|
assertFalse(input.readBoolean());
|
||||||
|
assertTrue(input.readBoolean());
|
||||||
|
|
||||||
|
final Set<Byte> set = IntStream.range(Byte.MIN_VALUE, Byte.MAX_VALUE).mapToObj(v -> (byte) v).collect(Collectors.toSet());
|
||||||
|
set.remove((byte) 0);
|
||||||
|
set.remove((byte) 1);
|
||||||
|
final byte[] corruptBytes = new byte[] { randomFrom(set) };
|
||||||
|
final BytesReference corrupt = new BytesArray(corruptBytes);
|
||||||
|
final IllegalStateException e = expectThrows(IllegalStateException.class, () -> corrupt.streamInput().readBoolean());
|
||||||
|
final String message = String.format(Locale.ROOT, "unexpected byte [0x%02x]", corruptBytes[0]);
|
||||||
|
assertThat(e, hasToString(containsString(message)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOptionalBooleanSerialization() throws IOException {
|
||||||
|
final BytesStreamOutput output = new BytesStreamOutput();
|
||||||
|
output.writeOptionalBoolean(false);
|
||||||
|
output.writeOptionalBoolean(true);
|
||||||
|
output.writeOptionalBoolean(null);
|
||||||
|
|
||||||
|
final BytesReference bytesReference = output.bytes();
|
||||||
|
|
||||||
|
final BytesRef bytesRef = bytesReference.toBytesRef();
|
||||||
|
assertThat(bytesRef.length, equalTo(3));
|
||||||
|
final byte[] bytes = bytesRef.bytes;
|
||||||
|
assertThat(bytes[0], equalTo((byte) 0));
|
||||||
|
assertThat(bytes[1], equalTo((byte) 1));
|
||||||
|
assertThat(bytes[2], equalTo((byte) 2));
|
||||||
|
|
||||||
|
final StreamInput input = bytesReference.streamInput();
|
||||||
|
final Boolean maybeFalse = input.readOptionalBoolean();
|
||||||
|
assertNotNull(maybeFalse);
|
||||||
|
assertFalse(maybeFalse);
|
||||||
|
final Boolean maybeTrue = input.readOptionalBoolean();
|
||||||
|
assertNotNull(maybeTrue);
|
||||||
|
assertTrue(maybeTrue);
|
||||||
|
assertNull(input.readOptionalBoolean());
|
||||||
|
|
||||||
|
final Set<Byte> set = IntStream.range(Byte.MIN_VALUE, Byte.MAX_VALUE).mapToObj(v -> (byte) v).collect(Collectors.toSet());
|
||||||
|
set.remove((byte) 0);
|
||||||
|
set.remove((byte) 1);
|
||||||
|
set.remove((byte) 2);
|
||||||
|
final byte[] corruptBytes = new byte[] { randomFrom(set) };
|
||||||
|
final BytesReference corrupt = new BytesArray(corruptBytes);
|
||||||
|
final IllegalStateException e = expectThrows(IllegalStateException.class, () -> corrupt.streamInput().readOptionalBoolean());
|
||||||
|
final String message = String.format(Locale.ROOT, "unexpected byte [0x%02x]", corruptBytes[0]);
|
||||||
|
assertThat(e, hasToString(containsString(message)));
|
||||||
|
}
|
||||||
|
|
||||||
public void testRandomVLongSerialization() throws IOException {
|
public void testRandomVLongSerialization() throws IOException {
|
||||||
for (int i = 0; i < 1024; i++) {
|
for (int i = 0; i < 1024; i++) {
|
||||||
long write = randomLong();
|
long write = randomLong();
|
||||||
|
@ -179,4 +249,5 @@ public class StreamTests extends ESTestCase {
|
||||||
out.writeString(string);
|
out.writeString(string);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue