[BUG] Serialization bugs can cause node drops (#1885)

This commit restructures InboundHandler to ensure all data 
is consumed over the wire.

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
This commit is contained in:
Andriy Redko 2022-01-14 14:02:34 -05:00 committed by GitHub
parent e7d44c20e9
commit f059738aac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 430 additions and 36 deletions

View File

@ -47,6 +47,7 @@ import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.threadpool.ThreadPool;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
@ -149,27 +150,13 @@ public class InboundHandler {
streamInput = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(streamInput, header.getVersion());
if (header.isError()) {
handlerResponseError(streamInput, handler);
handlerResponseError(requestId, streamInput, handler);
} else {
handleResponse(remoteAddress, streamInput, handler);
}
// Check the entire message has been read
final int nextByte = streamInput.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ handler
+ "], error ["
+ header.isError()
+ "]; resetting"
);
handleResponse(requestId, remoteAddress, streamInput, handler);
}
} else {
assert header.isError() == false;
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, handler);
handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler);
}
}
}
@ -246,22 +233,11 @@ public class InboundHandler {
assertRemoteVersion(stream, header.getVersion());
final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
assert reg != null;
final T request = reg.newRequest(stream);
final T request = newRequest(requestId, action, stream, reg);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (request) for requestId ["
+ requestId
+ "], action ["
+ action
+ "], available ["
+ stream.available()
+ "]; resetting"
);
}
checkStreamIsFullyConsumed(requestId, action, stream);
final String executor = reg.getExecutor();
if (ThreadPool.Names.SAME.equals(executor)) {
try {
@ -279,6 +255,97 @@ public class InboundHandler {
}
}
/**
* Creates new request instance out of input stream. Throws IllegalStateException if the end of
* the stream was reached before the request is fully deserialized from the stream.
* @param <T> transport request type
* @param requestId request identifier
* @param action action name
* @param stream stream
* @param reg request handler registry
* @return new request instance
* @throws IOException IOException
* @throws IllegalStateException IllegalStateException
*/
private <T extends TransportRequest> T newRequest(
final long requestId,
final String action,
final StreamInput stream,
final RequestHandlerRegistry<T> reg
) throws IOException {
try {
return reg.newRequest(stream);
} catch (final EOFException e) {
// Another favor of (de)serialization issues is when stream contains less bytes than
// the request handler needs to deserialize the payload.
throw new IllegalStateException(
"Message fully read (request) but more data is expected for requestId ["
+ requestId
+ "], action ["
+ action
+ "]; resetting",
e
);
}
}
/**
* Checks if the stream is fully consumed and throws the exceptions if that is not the case.
* @param requestId request identifier
* @param action action name
* @param stream stream
* @throws IOException IOException
*/
private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException {
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (request) for requestId ["
+ requestId
+ "], action ["
+ action
+ "], available ["
+ stream.available()
+ "]; resetting"
);
}
}
/**
* Checks if the stream is fully consumed and throws the exceptions if that is not the case.
* @param requestId request identifier
* @param handler response handler
* @param stream stream
* @param error "true" if response represents error, "false" otherwise
* @throws IOException IOException
*/
private void checkStreamIsFullyConsumed(
final long requestId,
final TransportResponseHandler<?> handler,
final StreamInput stream,
final boolean error
) throws IOException {
if (stream != EMPTY_STREAM_INPUT) {
// Check the entire message has been read
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ handler
+ "], error ["
+ error
+ "]; resetting"
);
}
}
}
private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) {
try {
transportChannel.sendResponse(e);
@ -289,6 +356,7 @@ public class InboundHandler {
}
private <T extends TransportResponse> void handleResponse(
final long requestId,
InetSocketAddress remoteAddress,
final StreamInput stream,
final TransportResponseHandler<T> handler
@ -297,6 +365,7 @@ public class InboundHandler {
try {
response = handler.read(stream);
response.remoteAddress(new TransportAddress(remoteAddress));
checkStreamIsFullyConsumed(requestId, handler, stream, false);
} catch (Exception e) {
final Exception serializationException = new TransportSerializationException(
"Failed to deserialize response from handler [" + handler + "]",
@ -322,10 +391,11 @@ public class InboundHandler {
}
}
private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler<?> handler) {
Exception error;
try {
error = stream.readException();
checkStreamIsFullyConsumed(requestId, handler, stream, true);
} catch (Exception e) {
error = new TransportSerializationException(
"Failed to deserialize exception response from stream for handler [" + handler + "]",

View File

@ -34,6 +34,7 @@ package org.opensearch.transport;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.lucene.util.BytesRef;
import org.opensearch.OpenSearchException;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
@ -57,13 +58,17 @@ import org.opensearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.Matchers.instanceOf;
public class InboundHandlerTests extends OpenSearchTestCase {
@ -75,16 +80,24 @@ public class InboundHandlerTests extends OpenSearchTestCase {
private Transport.ResponseHandlers responseHandlers;
private Transport.RequestHandlers requestHandlers;
private InboundHandler handler;
private OutboundHandler outboundHandler;
private FakeTcpChannel channel;
@Before
public void setUp() throws Exception {
super.setUp();
taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()) {
public void sendMessage(BytesReference reference, org.opensearch.action.ActionListener<Void> listener) {
super.sendMessage(reference, listener);
if (listener != null) {
listener.onResponse(null);
}
}
};
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
OutboundHandler outboundHandler = new OutboundHandler(
outboundHandler = new OutboundHandler(
"node",
version,
new String[0],
@ -211,7 +224,7 @@ public class InboundHandlerTests extends OpenSearchTestCase {
BytesReference fullResponseBytes = channel.getMessageCaptor().get();
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize);
Header responseHeader = new Header(fullRequestBytes.length() - 6, requestId, responseStatus, version);
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
handler.inboundMessage(channel, responseMessage);
@ -326,6 +339,317 @@ public class InboundHandlerTests extends OpenSearchTestCase {
}
}
public void testRequestNotFullyRead() throws Exception {
String action = "test-request";
int headerSize = TcpHeader.headerSize(version);
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
@Override
public void handleResponse(TestResponse response) {}
@Override
public void handleException(TransportException exp) {
exceptionCaptor.set(exp);
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
@Override
public TestResponse read(StreamInput in) throws IOException {
return new TestResponse(in);
}
}, null, action));
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
action,
TestRequest::new,
taskManager,
(request, channel, task) -> {},
ThreadPool.Names.SAME,
false,
true
);
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);
OutboundMessage.Request request = new OutboundMessage.Request(
threadPool.getThreadContext(),
new String[0],
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
outboundHandler.setMessageListener(new TransportMessageListener() {
@Override
public void onResponseSent(long requestId, String action, Exception error) {
exceptionCaptor.set(error);
}
});
// Create the request payload with 1 byte overflow
final BytesRef bytes = request.serialize(new BytesStreamOutput()).toBytesRef();
final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1);
buffer.put(bytes.bytes, 0, bytes.length);
buffer.put((byte) 1);
BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);
assertThat(exceptionCaptor.get(), instanceOf(IllegalStateException.class));
assertThat(exceptionCaptor.get().getMessage(), startsWith("Message not fully read (request) for requestId"));
}
public void testRequestFullyReadButMoreDataIsAvailable() throws Exception {
String action = "test-request";
int headerSize = TcpHeader.headerSize(version);
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
@Override
public void handleResponse(TestResponse response) {}
@Override
public void handleException(TransportException exp) {
exceptionCaptor.set(exp);
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
@Override
public TestResponse read(StreamInput in) throws IOException {
return new TestResponse(in);
}
}, null, action));
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
action,
TestRequest::new,
taskManager,
(request, channel, task) -> {},
ThreadPool.Names.SAME,
false,
true
);
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);
OutboundMessage.Request request = new OutboundMessage.Request(
threadPool.getThreadContext(),
new String[0],
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
outboundHandler.setMessageListener(new TransportMessageListener() {
@Override
public void onResponseSent(long requestId, String action, Exception error) {
exceptionCaptor.set(error);
}
});
final BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
// Create the request payload by intentionally stripping 1 byte away
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);
assertThat(exceptionCaptor.get(), instanceOf(IllegalStateException.class));
assertThat(exceptionCaptor.get().getCause(), instanceOf(EOFException.class));
assertThat(exceptionCaptor.get().getMessage(), startsWith("Message fully read (request) but more data is expected for requestId"));
}
public void testResponseNotFullyRead() throws Exception {
String action = "test-request";
int headerSize = TcpHeader.headerSize(version);
AtomicReference<TestRequest> requestCaptor = new AtomicReference<>();
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
AtomicReference<TestResponse> responseCaptor = new AtomicReference<>();
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
@Override
public void handleResponse(TestResponse response) {
responseCaptor.set(response);
}
@Override
public void handleException(TransportException exp) {
exceptionCaptor.set(exp);
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
@Override
public TestResponse read(StreamInput in) throws IOException {
return new TestResponse(in);
}
}, null, action));
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
action,
TestRequest::new,
taskManager,
(request, channel, task) -> {
channelCaptor.set(channel);
requestCaptor.set(request);
},
ThreadPool.Names.SAME,
false,
true
);
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);
OutboundMessage.Request request = new OutboundMessage.Request(
threadPool.getThreadContext(),
new String[0],
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);
TransportChannel transportChannel = channelCaptor.get();
assertEquals(Version.CURRENT, transportChannel.getVersion());
assertEquals("transport", transportChannel.getChannelType());
assertEquals(requestValue, requestCaptor.get().value);
String responseValue = randomAlphaOfLength(10);
byte responseStatus = TransportStatus.setResponse((byte) 0);
transportChannel.sendResponse(new TestResponse(responseValue));
// Create the response payload with 1 byte overflow
final BytesRef bytes = channel.getMessageCaptor().get().toBytesRef();
final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1);
buffer.put(bytes.bytes, 0, bytes.length);
buffer.put((byte) 1);
BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip());
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize);
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
handler.inboundMessage(channel, responseMessage);
assertThat(exceptionCaptor.get(), instanceOf(RemoteTransportException.class));
assertThat(exceptionCaptor.get().getCause(), instanceOf(TransportSerializationException.class));
assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler"));
}
public void testResponseFullyReadButMoreDataIsAvailable() throws Exception {
String action = "test-request";
int headerSize = TcpHeader.headerSize(version);
AtomicReference<TestRequest> requestCaptor = new AtomicReference<>();
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
AtomicReference<TestResponse> responseCaptor = new AtomicReference<>();
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
@Override
public void handleResponse(TestResponse response) {
responseCaptor.set(response);
}
@Override
public void handleException(TransportException exp) {
exceptionCaptor.set(exp);
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
@Override
public TestResponse read(StreamInput in) throws IOException {
return new TestResponse(in);
}
}, null, action));
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
action,
TestRequest::new,
taskManager,
(request, channel, task) -> {
channelCaptor.set(channel);
requestCaptor.set(request);
},
ThreadPool.Names.SAME,
false,
true
);
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);
OutboundMessage.Request request = new OutboundMessage.Request(
threadPool.getThreadContext(),
new String[0],
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
handler.inboundMessage(channel, requestMessage);
TransportChannel transportChannel = channelCaptor.get();
assertEquals(Version.CURRENT, transportChannel.getVersion());
assertEquals("transport", transportChannel.getChannelType());
assertEquals(requestValue, requestCaptor.get().value);
String responseValue = randomAlphaOfLength(10);
byte responseStatus = TransportStatus.setResponse((byte) 0);
transportChannel.sendResponse(new TestResponse(responseValue));
BytesReference fullResponseBytes = channel.getMessageCaptor().get();
// Create the response payload by intentionally stripping 1 byte away
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1);
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
handler.inboundMessage(channel, responseMessage);
assertThat(exceptionCaptor.get(), instanceOf(RemoteTransportException.class));
assertThat(exceptionCaptor.get().getCause(), instanceOf(TransportSerializationException.class));
assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler"));
}
private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) {
return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) {
@Override