diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index bf4ada27900..6aa319934b4 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -47,6 +47,7 @@ import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.threadpool.ThreadPool; +import java.io.EOFException; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; @@ -149,27 +150,13 @@ public class InboundHandler { streamInput = namedWriteableStream(message.openOrGetStreamInput()); assertRemoteVersion(streamInput, header.getVersion()); if (header.isError()) { - handlerResponseError(streamInput, handler); + handlerResponseError(requestId, streamInput, handler); } else { - handleResponse(remoteAddress, streamInput, handler); - } - // Check the entire message has been read - final int nextByte = streamInput.read(); - // calling read() is useful to make sure the message is fully read, even if there is an EOS marker - if (nextByte != -1) { - throw new IllegalStateException( - "Message not fully read (response) for requestId [" - + requestId - + "], handler [" - + handler - + "], error [" - + header.isError() - + "]; resetting" - ); + handleResponse(requestId, remoteAddress, streamInput, handler); } } else { assert header.isError() == false; - handleResponse(remoteAddress, EMPTY_STREAM_INPUT, handler); + handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler); } } } @@ -246,22 +233,11 @@ public class InboundHandler { assertRemoteVersion(stream, header.getVersion()); final RequestHandlerRegistry reg = requestHandlers.getHandler(action); assert reg != null; - final T request = reg.newRequest(stream); + + final T request = newRequest(requestId, action, stream, reg); request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); - // in case we throw an exception, i.e. when the limit is hit, we don't want to verify - final int nextByte = stream.read(); - // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker - if (nextByte != -1) { - throw new IllegalStateException( - "Message not fully read (request) for requestId [" - + requestId - + "], action [" - + action - + "], available [" - + stream.available() - + "]; resetting" - ); - } + checkStreamIsFullyConsumed(requestId, action, stream); + final String executor = reg.getExecutor(); if (ThreadPool.Names.SAME.equals(executor)) { try { @@ -279,6 +255,97 @@ public class InboundHandler { } } + /** + * Creates new request instance out of input stream. Throws IllegalStateException if the end of + * the stream was reached before the request is fully deserialized from the stream. + * @param transport request type + * @param requestId request identifier + * @param action action name + * @param stream stream + * @param reg request handler registry + * @return new request instance + * @throws IOException IOException + * @throws IllegalStateException IllegalStateException + */ + private T newRequest( + final long requestId, + final String action, + final StreamInput stream, + final RequestHandlerRegistry reg + ) throws IOException { + try { + return reg.newRequest(stream); + } catch (final EOFException e) { + // Another favor of (de)serialization issues is when stream contains less bytes than + // the request handler needs to deserialize the payload. + throw new IllegalStateException( + "Message fully read (request) but more data is expected for requestId [" + + requestId + + "], action [" + + action + + "]; resetting", + e + ); + } + } + + /** + * Checks if the stream is fully consumed and throws the exceptions if that is not the case. + * @param requestId request identifier + * @param action action name + * @param stream stream + * @throws IOException IOException + */ + private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException { + // in case we throw an exception, i.e. when the limit is hit, we don't want to verify + final int nextByte = stream.read(); + + // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker + if (nextByte != -1) { + throw new IllegalStateException( + "Message not fully read (request) for requestId [" + + requestId + + "], action [" + + action + + "], available [" + + stream.available() + + "]; resetting" + ); + } + } + + /** + * Checks if the stream is fully consumed and throws the exceptions if that is not the case. + * @param requestId request identifier + * @param handler response handler + * @param stream stream + * @param error "true" if response represents error, "false" otherwise + * @throws IOException IOException + */ + private void checkStreamIsFullyConsumed( + final long requestId, + final TransportResponseHandler handler, + final StreamInput stream, + final boolean error + ) throws IOException { + if (stream != EMPTY_STREAM_INPUT) { + // Check the entire message has been read + final int nextByte = stream.read(); + // calling read() is useful to make sure the message is fully read, even if there is an EOS marker + if (nextByte != -1) { + throw new IllegalStateException( + "Message not fully read (response) for requestId [" + + requestId + + "], handler [" + + handler + + "], error [" + + error + + "]; resetting" + ); + } + } + } + private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { try { transportChannel.sendResponse(e); @@ -289,6 +356,7 @@ public class InboundHandler { } private void handleResponse( + final long requestId, InetSocketAddress remoteAddress, final StreamInput stream, final TransportResponseHandler handler @@ -297,6 +365,7 @@ public class InboundHandler { try { response = handler.read(stream); response.remoteAddress(new TransportAddress(remoteAddress)); + checkStreamIsFullyConsumed(requestId, handler, stream, false); } catch (Exception e) { final Exception serializationException = new TransportSerializationException( "Failed to deserialize response from handler [" + handler + "]", @@ -322,10 +391,11 @@ public class InboundHandler { } } - private void handlerResponseError(StreamInput stream, final TransportResponseHandler handler) { + private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler handler) { Exception error; try { error = stream.readException(); + checkStreamIsFullyConsumed(requestId, handler, stream, true); } catch (Exception e) { error = new TransportSerializationException( "Failed to deserialize exception response from stream for handler [" + handler + "]", diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 882a783b667..4076e7229eb 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -34,6 +34,7 @@ package org.opensearch.transport; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; +import org.apache.lucene.util.BytesRef; import org.opensearch.OpenSearchException; import org.opensearch.Version; import org.opensearch.action.ActionListener; @@ -57,13 +58,17 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Collections; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.Matchers.instanceOf; public class InboundHandlerTests extends OpenSearchTestCase { @@ -75,16 +80,24 @@ public class InboundHandlerTests extends OpenSearchTestCase { private Transport.ResponseHandlers responseHandlers; private Transport.RequestHandlers requestHandlers; private InboundHandler handler; + private OutboundHandler outboundHandler; private FakeTcpChannel channel; @Before public void setUp() throws Exception { super.setUp(); taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); - channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); + channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()) { + public void sendMessage(BytesReference reference, org.opensearch.action.ActionListener listener) { + super.sendMessage(reference, listener); + if (listener != null) { + listener.onResponse(null); + } + } + }; NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}); - OutboundHandler outboundHandler = new OutboundHandler( + outboundHandler = new OutboundHandler( "node", version, new String[0], @@ -211,7 +224,7 @@ public class InboundHandlerTests extends OpenSearchTestCase { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); - Header responseHeader = new Header(fullRequestBytes.length() - 6, requestId, responseStatus, version); + Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -326,6 +339,317 @@ public class InboundHandlerTests extends OpenSearchTestCase { } } + public void testRequestNotFullyRead() throws Exception { + String action = "test-request"; + int headerSize = TcpHeader.headerSize(version); + AtomicReference exceptionCaptor = new AtomicReference<>(); + + long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler() { + @Override + public void handleResponse(TestResponse response) {} + + @Override + public void handleException(TransportException exp) { + exceptionCaptor.set(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public TestResponse read(StreamInput in) throws IOException { + return new TestResponse(in); + } + }, null, action)); + + RequestHandlerRegistry registry = new RequestHandlerRegistry<>( + action, + TestRequest::new, + taskManager, + (request, channel, task) -> {}, + ThreadPool.Names.SAME, + false, + true + ); + + requestHandlers.registerHandler(registry); + String requestValue = randomAlphaOfLength(10); + OutboundMessage.Request request = new OutboundMessage.Request( + threadPool.getThreadContext(), + new String[0], + new TestRequest(requestValue), + version, + action, + requestId, + false, + false + ); + + outboundHandler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, Exception error) { + exceptionCaptor.set(error); + } + }); + + // Create the request payload with 1 byte overflow + final BytesRef bytes = request.serialize(new BytesStreamOutput()).toBytesRef(); + final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1); + buffer.put(bytes.bytes, 0, bytes.length); + buffer.put((byte) 1); + + BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); + BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); + Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); + handler.inboundMessage(channel, requestMessage); + + assertThat(exceptionCaptor.get(), instanceOf(IllegalStateException.class)); + assertThat(exceptionCaptor.get().getMessage(), startsWith("Message not fully read (request) for requestId")); + } + + public void testRequestFullyReadButMoreDataIsAvailable() throws Exception { + String action = "test-request"; + int headerSize = TcpHeader.headerSize(version); + AtomicReference exceptionCaptor = new AtomicReference<>(); + + long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler() { + @Override + public void handleResponse(TestResponse response) {} + + @Override + public void handleException(TransportException exp) { + exceptionCaptor.set(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public TestResponse read(StreamInput in) throws IOException { + return new TestResponse(in); + } + }, null, action)); + + RequestHandlerRegistry registry = new RequestHandlerRegistry<>( + action, + TestRequest::new, + taskManager, + (request, channel, task) -> {}, + ThreadPool.Names.SAME, + false, + true + ); + + requestHandlers.registerHandler(registry); + String requestValue = randomAlphaOfLength(10); + OutboundMessage.Request request = new OutboundMessage.Request( + threadPool.getThreadContext(), + new String[0], + new TestRequest(requestValue), + version, + action, + requestId, + false, + false + ); + + outboundHandler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, Exception error) { + exceptionCaptor.set(error); + } + }); + + final BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); + // Create the request payload by intentionally stripping 1 byte away + BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); + Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); + handler.inboundMessage(channel, requestMessage); + + assertThat(exceptionCaptor.get(), instanceOf(IllegalStateException.class)); + assertThat(exceptionCaptor.get().getCause(), instanceOf(EOFException.class)); + assertThat(exceptionCaptor.get().getMessage(), startsWith("Message fully read (request) but more data is expected for requestId")); + } + + public void testResponseNotFullyRead() throws Exception { + String action = "test-request"; + int headerSize = TcpHeader.headerSize(version); + AtomicReference requestCaptor = new AtomicReference<>(); + AtomicReference exceptionCaptor = new AtomicReference<>(); + AtomicReference responseCaptor = new AtomicReference<>(); + AtomicReference channelCaptor = new AtomicReference<>(); + + long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler() { + @Override + public void handleResponse(TestResponse response) { + responseCaptor.set(response); + } + + @Override + public void handleException(TransportException exp) { + exceptionCaptor.set(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public TestResponse read(StreamInput in) throws IOException { + return new TestResponse(in); + } + }, null, action)); + RequestHandlerRegistry registry = new RequestHandlerRegistry<>( + action, + TestRequest::new, + taskManager, + (request, channel, task) -> { + channelCaptor.set(channel); + requestCaptor.set(request); + }, + ThreadPool.Names.SAME, + false, + true + ); + requestHandlers.registerHandler(registry); + String requestValue = randomAlphaOfLength(10); + OutboundMessage.Request request = new OutboundMessage.Request( + threadPool.getThreadContext(), + new String[0], + new TestRequest(requestValue), + version, + action, + requestId, + false, + false + ); + + BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); + BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); + Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); + handler.inboundMessage(channel, requestMessage); + + TransportChannel transportChannel = channelCaptor.get(); + assertEquals(Version.CURRENT, transportChannel.getVersion()); + assertEquals("transport", transportChannel.getChannelType()); + assertEquals(requestValue, requestCaptor.get().value); + + String responseValue = randomAlphaOfLength(10); + byte responseStatus = TransportStatus.setResponse((byte) 0); + transportChannel.sendResponse(new TestResponse(responseValue)); + + // Create the response payload with 1 byte overflow + final BytesRef bytes = channel.getMessageCaptor().get().toBytesRef(); + final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1); + buffer.put(bytes.bytes, 0, bytes.length); + buffer.put((byte) 1); + + BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); + BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); + Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); + responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); + handler.inboundMessage(channel, responseMessage); + + assertThat(exceptionCaptor.get(), instanceOf(RemoteTransportException.class)); + assertThat(exceptionCaptor.get().getCause(), instanceOf(TransportSerializationException.class)); + assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler")); + } + + public void testResponseFullyReadButMoreDataIsAvailable() throws Exception { + String action = "test-request"; + int headerSize = TcpHeader.headerSize(version); + AtomicReference requestCaptor = new AtomicReference<>(); + AtomicReference exceptionCaptor = new AtomicReference<>(); + AtomicReference responseCaptor = new AtomicReference<>(); + AtomicReference channelCaptor = new AtomicReference<>(); + + long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler() { + @Override + public void handleResponse(TestResponse response) { + responseCaptor.set(response); + } + + @Override + public void handleException(TransportException exp) { + exceptionCaptor.set(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public TestResponse read(StreamInput in) throws IOException { + return new TestResponse(in); + } + }, null, action)); + RequestHandlerRegistry registry = new RequestHandlerRegistry<>( + action, + TestRequest::new, + taskManager, + (request, channel, task) -> { + channelCaptor.set(channel); + requestCaptor.set(request); + }, + ThreadPool.Names.SAME, + false, + true + ); + requestHandlers.registerHandler(registry); + String requestValue = randomAlphaOfLength(10); + OutboundMessage.Request request = new OutboundMessage.Request( + threadPool.getThreadContext(), + new String[0], + new TestRequest(requestValue), + version, + action, + requestId, + false, + false + ); + + BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput()); + BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); + Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); + requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); + handler.inboundMessage(channel, requestMessage); + + TransportChannel transportChannel = channelCaptor.get(); + assertEquals(Version.CURRENT, transportChannel.getVersion()); + assertEquals("transport", transportChannel.getChannelType()); + assertEquals(requestValue, requestCaptor.get().value); + + String responseValue = randomAlphaOfLength(10); + byte responseStatus = TransportStatus.setResponse((byte) 0); + transportChannel.sendResponse(new TestResponse(responseValue)); + + BytesReference fullResponseBytes = channel.getMessageCaptor().get(); + // Create the response payload by intentionally stripping 1 byte away + BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1); + Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); + responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); + handler.inboundMessage(channel, responseMessage); + + assertThat(exceptionCaptor.get(), instanceOf(RemoteTransportException.class)); + assertThat(exceptionCaptor.get().getCause(), instanceOf(TransportSerializationException.class)); + assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler")); + } + private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { @Override