Throw an exception if Writeable.Reader reads null

If a Writeable.Reader returns null it is always a bug, probably one that
will cause corruption in the StreamInput it was trying to read from. This
commit adds a check that attempts to catch these errors quickly including
the name of the reader.
This commit is contained in:
Nik Everett 2016-03-24 12:26:33 -04:00
parent 6dd164d0bd
commit 5e8656aff0
4 changed files with 38 additions and 4 deletions

View File

@ -36,6 +36,12 @@ public class NamedWriteableAwareStreamInput extends FilterStreamInput {
@Override
<C> C readNamedWriteable(Class<C> categoryClass) throws IOException {
String name = readString();
return namedWriteableRegistry.getReader(categoryClass, name).read(this);
Writeable.Reader<? extends C> reader = namedWriteableRegistry.getReader(categoryClass, name);
C c = reader.read(this);
if (c == null) {
throw new IOException(
"Writeable.Reader [" + reader + "] returned null which is not allowed and probably means it screwed up the stream.");
}
return c;
}
}

View File

@ -566,9 +566,14 @@ public abstract class StreamInput extends InputStream {
}
}
public <T extends Writeable> T readOptionalWriteable(Writeable.Reader<T> provider) throws IOException {
public <T extends Writeable> T readOptionalWriteable(Writeable.Reader<T> reader) throws IOException {
if (readBoolean()) {
return provider.read(this);
T t = reader.read(this);
if (t == null) {
throw new IOException("Writeable.Reader [" + reader
+ "] returned null which is not allowed and probably means it screwed up the stream.");
}
return t;
} else {
return null;
}

View File

@ -51,7 +51,8 @@ public interface Writeable<T> extends StreamableReader<T> { // TODO remove exten
/**
* Reference to a method that can read some object from a stream. By convention this is a constructor that takes
* {@linkplain StreamInput} as an argument for most classes and a static method for things like enums.
* {@linkplain StreamInput} as an argument for most classes and a static method for things like enums. Returning null from one of these
* is always wrong - for that we use methods like {@link StreamInput#readOptionalWriteable(Reader)}.
*/
@FunctionalInterface
interface Reader<R> {

View File

@ -29,6 +29,7 @@ import java.io.IOException;
import java.util.Objects;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
@ -373,6 +374,27 @@ public class BytesStreamsTests extends ESTestCase {
}
}
public void testNamedWriteableReaderReturnsNull() throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry();
namedWriteableRegistry.register(BaseNamedWriteable.class, TestNamedWriteable.NAME, (StreamInput in) -> null);
TestNamedWriteable namedWriteableIn = new TestNamedWriteable(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10));
out.writeNamedWriteable(namedWriteableIn);
byte[] bytes = out.bytes().toBytes();
StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(bytes), namedWriteableRegistry);
assertEquals(in.available(), bytes.length);
IOException e = expectThrows(IOException.class, () -> in.readNamedWriteable(BaseNamedWriteable.class));
assertThat(e.getMessage(), endsWith("] returned null which is not allowed and probably means it screwed up the stream."));
}
public void testOptionalWriteableReaderReturnsNull() throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
out.writeOptionalWriteable(new TestNamedWriteable(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10)));
StreamInput in = StreamInput.wrap(out.bytes().toBytes());
IOException e = expectThrows(IOException.class, () -> in.readOptionalWriteable((StreamInput ignored) -> null));
assertThat(e.getMessage(), endsWith("] returned null which is not allowed and probably means it screwed up the stream."));
}
private static abstract class BaseNamedWriteable<T> implements NamedWriteable<T> {
}