[BUG] Serialization bugs can cause node drops (#1885)
This commit restructures InboundHandler to ensure all data is consumed over the wire. Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
This commit is contained in:
parent
e7d44c20e9
commit
f059738aac
|
@ -47,6 +47,7 @@ import org.opensearch.common.util.concurrent.AbstractRunnable;
|
|||
import org.opensearch.common.util.concurrent.ThreadContext;
|
||||
import org.opensearch.threadpool.ThreadPool;
|
||||
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -149,27 +150,13 @@ public class InboundHandler {
|
|||
streamInput = namedWriteableStream(message.openOrGetStreamInput());
|
||||
assertRemoteVersion(streamInput, header.getVersion());
|
||||
if (header.isError()) {
|
||||
handlerResponseError(streamInput, handler);
|
||||
handlerResponseError(requestId, streamInput, handler);
|
||||
} else {
|
||||
handleResponse(remoteAddress, streamInput, handler);
|
||||
}
|
||||
// Check the entire message has been read
|
||||
final int nextByte = streamInput.read();
|
||||
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
|
||||
if (nextByte != -1) {
|
||||
throw new IllegalStateException(
|
||||
"Message not fully read (response) for requestId ["
|
||||
+ requestId
|
||||
+ "], handler ["
|
||||
+ handler
|
||||
+ "], error ["
|
||||
+ header.isError()
|
||||
+ "]; resetting"
|
||||
);
|
||||
handleResponse(requestId, remoteAddress, streamInput, handler);
|
||||
}
|
||||
} else {
|
||||
assert header.isError() == false;
|
||||
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, handler);
|
||||
handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -246,22 +233,11 @@ public class InboundHandler {
|
|||
assertRemoteVersion(stream, header.getVersion());
|
||||
final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
|
||||
assert reg != null;
|
||||
final T request = reg.newRequest(stream);
|
||||
|
||||
final T request = newRequest(requestId, action, stream, reg);
|
||||
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
|
||||
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
|
||||
final int nextByte = stream.read();
|
||||
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
|
||||
if (nextByte != -1) {
|
||||
throw new IllegalStateException(
|
||||
"Message not fully read (request) for requestId ["
|
||||
+ requestId
|
||||
+ "], action ["
|
||||
+ action
|
||||
+ "], available ["
|
||||
+ stream.available()
|
||||
+ "]; resetting"
|
||||
);
|
||||
}
|
||||
checkStreamIsFullyConsumed(requestId, action, stream);
|
||||
|
||||
final String executor = reg.getExecutor();
|
||||
if (ThreadPool.Names.SAME.equals(executor)) {
|
||||
try {
|
||||
|
@ -279,6 +255,97 @@ public class InboundHandler {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates new request instance out of input stream. Throws IllegalStateException if the end of
|
||||
* the stream was reached before the request is fully deserialized from the stream.
|
||||
* @param <T> transport request type
|
||||
* @param requestId request identifier
|
||||
* @param action action name
|
||||
* @param stream stream
|
||||
* @param reg request handler registry
|
||||
* @return new request instance
|
||||
* @throws IOException IOException
|
||||
* @throws IllegalStateException IllegalStateException
|
||||
*/
|
||||
private <T extends TransportRequest> T newRequest(
|
||||
final long requestId,
|
||||
final String action,
|
||||
final StreamInput stream,
|
||||
final RequestHandlerRegistry<T> reg
|
||||
) throws IOException {
|
||||
try {
|
||||
return reg.newRequest(stream);
|
||||
} catch (final EOFException e) {
|
||||
// Another favor of (de)serialization issues is when stream contains less bytes than
|
||||
// the request handler needs to deserialize the payload.
|
||||
throw new IllegalStateException(
|
||||
"Message fully read (request) but more data is expected for requestId ["
|
||||
+ requestId
|
||||
+ "], action ["
|
||||
+ action
|
||||
+ "]; resetting",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the stream is fully consumed and throws the exceptions if that is not the case.
|
||||
* @param requestId request identifier
|
||||
* @param action action name
|
||||
* @param stream stream
|
||||
* @throws IOException IOException
|
||||
*/
|
||||
private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException {
|
||||
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
|
||||
final int nextByte = stream.read();
|
||||
|
||||
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
|
||||
if (nextByte != -1) {
|
||||
throw new IllegalStateException(
|
||||
"Message not fully read (request) for requestId ["
|
||||
+ requestId
|
||||
+ "], action ["
|
||||
+ action
|
||||
+ "], available ["
|
||||
+ stream.available()
|
||||
+ "]; resetting"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the stream is fully consumed and throws the exceptions if that is not the case.
|
||||
* @param requestId request identifier
|
||||
* @param handler response handler
|
||||
* @param stream stream
|
||||
* @param error "true" if response represents error, "false" otherwise
|
||||
* @throws IOException IOException
|
||||
*/
|
||||
private void checkStreamIsFullyConsumed(
|
||||
final long requestId,
|
||||
final TransportResponseHandler<?> handler,
|
||||
final StreamInput stream,
|
||||
final boolean error
|
||||
) throws IOException {
|
||||
if (stream != EMPTY_STREAM_INPUT) {
|
||||
// Check the entire message has been read
|
||||
final int nextByte = stream.read();
|
||||
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
|
||||
if (nextByte != -1) {
|
||||
throw new IllegalStateException(
|
||||
"Message not fully read (response) for requestId ["
|
||||
+ requestId
|
||||
+ "], handler ["
|
||||
+ handler
|
||||
+ "], error ["
|
||||
+ error
|
||||
+ "]; resetting"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) {
|
||||
try {
|
||||
transportChannel.sendResponse(e);
|
||||
|
@ -289,6 +356,7 @@ public class InboundHandler {
|
|||
}
|
||||
|
||||
private <T extends TransportResponse> void handleResponse(
|
||||
final long requestId,
|
||||
InetSocketAddress remoteAddress,
|
||||
final StreamInput stream,
|
||||
final TransportResponseHandler<T> handler
|
||||
|
@ -297,6 +365,7 @@ public class InboundHandler {
|
|||
try {
|
||||
response = handler.read(stream);
|
||||
response.remoteAddress(new TransportAddress(remoteAddress));
|
||||
checkStreamIsFullyConsumed(requestId, handler, stream, false);
|
||||
} catch (Exception e) {
|
||||
final Exception serializationException = new TransportSerializationException(
|
||||
"Failed to deserialize response from handler [" + handler + "]",
|
||||
|
@ -322,10 +391,11 @@ public class InboundHandler {
|
|||
}
|
||||
}
|
||||
|
||||
private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
|
||||
private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler<?> handler) {
|
||||
Exception error;
|
||||
try {
|
||||
error = stream.readException();
|
||||
checkStreamIsFullyConsumed(requestId, handler, stream, true);
|
||||
} catch (Exception e) {
|
||||
error = new TransportSerializationException(
|
||||
"Failed to deserialize exception response from stream for handler [" + handler + "]",
|
||||
|
|
|
@ -34,6 +34,7 @@ package org.opensearch.transport;
|
|||
|
||||
import org.apache.logging.log4j.Level;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.opensearch.OpenSearchException;
|
||||
import org.opensearch.Version;
|
||||
import org.opensearch.action.ActionListener;
|
||||
|
@ -57,13 +58,17 @@ import org.opensearch.threadpool.ThreadPool;
|
|||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
import static org.hamcrest.CoreMatchers.startsWith;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
||||
public class InboundHandlerTests extends OpenSearchTestCase {
|
||||
|
@ -75,16 +80,24 @@ public class InboundHandlerTests extends OpenSearchTestCase {
|
|||
private Transport.ResponseHandlers responseHandlers;
|
||||
private Transport.RequestHandlers requestHandlers;
|
||||
private InboundHandler handler;
|
||||
private OutboundHandler outboundHandler;
|
||||
private FakeTcpChannel channel;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
|
||||
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
|
||||
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()) {
|
||||
public void sendMessage(BytesReference reference, org.opensearch.action.ActionListener<Void> listener) {
|
||||
super.sendMessage(reference, listener);
|
||||
if (listener != null) {
|
||||
listener.onResponse(null);
|
||||
}
|
||||
}
|
||||
};
|
||||
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
|
||||
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
|
||||
OutboundHandler outboundHandler = new OutboundHandler(
|
||||
outboundHandler = new OutboundHandler(
|
||||
"node",
|
||||
version,
|
||||
new String[0],
|
||||
|
@ -211,7 +224,7 @@ public class InboundHandlerTests extends OpenSearchTestCase {
|
|||
|
||||
BytesReference fullResponseBytes = channel.getMessageCaptor().get();
|
||||
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize);
|
||||
Header responseHeader = new Header(fullRequestBytes.length() - 6, requestId, responseStatus, version);
|
||||
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
|
||||
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
|
||||
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, responseMessage);
|
||||
|
@ -326,6 +339,317 @@ public class InboundHandlerTests extends OpenSearchTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testRequestNotFullyRead() throws Exception {
|
||||
String action = "test-request";
|
||||
int headerSize = TcpHeader.headerSize(version);
|
||||
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
|
||||
|
||||
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
|
||||
@Override
|
||||
public void handleResponse(TestResponse response) {}
|
||||
|
||||
@Override
|
||||
public void handleException(TransportException exp) {
|
||||
exceptionCaptor.set(exp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executor() {
|
||||
return ThreadPool.Names.SAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TestResponse read(StreamInput in) throws IOException {
|
||||
return new TestResponse(in);
|
||||
}
|
||||
}, null, action));
|
||||
|
||||
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
|
||||
action,
|
||||
TestRequest::new,
|
||||
taskManager,
|
||||
(request, channel, task) -> {},
|
||||
ThreadPool.Names.SAME,
|
||||
false,
|
||||
true
|
||||
);
|
||||
|
||||
requestHandlers.registerHandler(registry);
|
||||
String requestValue = randomAlphaOfLength(10);
|
||||
OutboundMessage.Request request = new OutboundMessage.Request(
|
||||
threadPool.getThreadContext(),
|
||||
new String[0],
|
||||
new TestRequest(requestValue),
|
||||
version,
|
||||
action,
|
||||
requestId,
|
||||
false,
|
||||
false
|
||||
);
|
||||
|
||||
outboundHandler.setMessageListener(new TransportMessageListener() {
|
||||
@Override
|
||||
public void onResponseSent(long requestId, String action, Exception error) {
|
||||
exceptionCaptor.set(error);
|
||||
}
|
||||
});
|
||||
|
||||
// Create the request payload with 1 byte overflow
|
||||
final BytesRef bytes = request.serialize(new BytesStreamOutput()).toBytesRef();
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1);
|
||||
buffer.put(bytes.bytes, 0, bytes.length);
|
||||
buffer.put((byte) 1);
|
||||
|
||||
BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip());
|
||||
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
|
||||
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
|
||||
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
|
||||
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, requestMessage);
|
||||
|
||||
assertThat(exceptionCaptor.get(), instanceOf(IllegalStateException.class));
|
||||
assertThat(exceptionCaptor.get().getMessage(), startsWith("Message not fully read (request) for requestId"));
|
||||
}
|
||||
|
||||
public void testRequestFullyReadButMoreDataIsAvailable() throws Exception {
|
||||
String action = "test-request";
|
||||
int headerSize = TcpHeader.headerSize(version);
|
||||
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
|
||||
|
||||
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
|
||||
@Override
|
||||
public void handleResponse(TestResponse response) {}
|
||||
|
||||
@Override
|
||||
public void handleException(TransportException exp) {
|
||||
exceptionCaptor.set(exp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executor() {
|
||||
return ThreadPool.Names.SAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TestResponse read(StreamInput in) throws IOException {
|
||||
return new TestResponse(in);
|
||||
}
|
||||
}, null, action));
|
||||
|
||||
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
|
||||
action,
|
||||
TestRequest::new,
|
||||
taskManager,
|
||||
(request, channel, task) -> {},
|
||||
ThreadPool.Names.SAME,
|
||||
false,
|
||||
true
|
||||
);
|
||||
|
||||
requestHandlers.registerHandler(registry);
|
||||
String requestValue = randomAlphaOfLength(10);
|
||||
OutboundMessage.Request request = new OutboundMessage.Request(
|
||||
threadPool.getThreadContext(),
|
||||
new String[0],
|
||||
new TestRequest(requestValue),
|
||||
version,
|
||||
action,
|
||||
requestId,
|
||||
false,
|
||||
false
|
||||
);
|
||||
|
||||
outboundHandler.setMessageListener(new TransportMessageListener() {
|
||||
@Override
|
||||
public void onResponseSent(long requestId, String action, Exception error) {
|
||||
exceptionCaptor.set(error);
|
||||
}
|
||||
});
|
||||
|
||||
final BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
|
||||
// Create the request payload by intentionally stripping 1 byte away
|
||||
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1);
|
||||
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
|
||||
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
|
||||
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, requestMessage);
|
||||
|
||||
assertThat(exceptionCaptor.get(), instanceOf(IllegalStateException.class));
|
||||
assertThat(exceptionCaptor.get().getCause(), instanceOf(EOFException.class));
|
||||
assertThat(exceptionCaptor.get().getMessage(), startsWith("Message fully read (request) but more data is expected for requestId"));
|
||||
}
|
||||
|
||||
public void testResponseNotFullyRead() throws Exception {
|
||||
String action = "test-request";
|
||||
int headerSize = TcpHeader.headerSize(version);
|
||||
AtomicReference<TestRequest> requestCaptor = new AtomicReference<>();
|
||||
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
|
||||
AtomicReference<TestResponse> responseCaptor = new AtomicReference<>();
|
||||
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
|
||||
|
||||
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
|
||||
@Override
|
||||
public void handleResponse(TestResponse response) {
|
||||
responseCaptor.set(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleException(TransportException exp) {
|
||||
exceptionCaptor.set(exp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executor() {
|
||||
return ThreadPool.Names.SAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TestResponse read(StreamInput in) throws IOException {
|
||||
return new TestResponse(in);
|
||||
}
|
||||
}, null, action));
|
||||
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
|
||||
action,
|
||||
TestRequest::new,
|
||||
taskManager,
|
||||
(request, channel, task) -> {
|
||||
channelCaptor.set(channel);
|
||||
requestCaptor.set(request);
|
||||
},
|
||||
ThreadPool.Names.SAME,
|
||||
false,
|
||||
true
|
||||
);
|
||||
requestHandlers.registerHandler(registry);
|
||||
String requestValue = randomAlphaOfLength(10);
|
||||
OutboundMessage.Request request = new OutboundMessage.Request(
|
||||
threadPool.getThreadContext(),
|
||||
new String[0],
|
||||
new TestRequest(requestValue),
|
||||
version,
|
||||
action,
|
||||
requestId,
|
||||
false,
|
||||
false
|
||||
);
|
||||
|
||||
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
|
||||
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
|
||||
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
|
||||
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
|
||||
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, requestMessage);
|
||||
|
||||
TransportChannel transportChannel = channelCaptor.get();
|
||||
assertEquals(Version.CURRENT, transportChannel.getVersion());
|
||||
assertEquals("transport", transportChannel.getChannelType());
|
||||
assertEquals(requestValue, requestCaptor.get().value);
|
||||
|
||||
String responseValue = randomAlphaOfLength(10);
|
||||
byte responseStatus = TransportStatus.setResponse((byte) 0);
|
||||
transportChannel.sendResponse(new TestResponse(responseValue));
|
||||
|
||||
// Create the response payload with 1 byte overflow
|
||||
final BytesRef bytes = channel.getMessageCaptor().get().toBytesRef();
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1);
|
||||
buffer.put(bytes.bytes, 0, bytes.length);
|
||||
buffer.put((byte) 1);
|
||||
|
||||
BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip());
|
||||
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize);
|
||||
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
|
||||
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
|
||||
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, responseMessage);
|
||||
|
||||
assertThat(exceptionCaptor.get(), instanceOf(RemoteTransportException.class));
|
||||
assertThat(exceptionCaptor.get().getCause(), instanceOf(TransportSerializationException.class));
|
||||
assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler"));
|
||||
}
|
||||
|
||||
public void testResponseFullyReadButMoreDataIsAvailable() throws Exception {
|
||||
String action = "test-request";
|
||||
int headerSize = TcpHeader.headerSize(version);
|
||||
AtomicReference<TestRequest> requestCaptor = new AtomicReference<>();
|
||||
AtomicReference<Exception> exceptionCaptor = new AtomicReference<>();
|
||||
AtomicReference<TestResponse> responseCaptor = new AtomicReference<>();
|
||||
AtomicReference<TransportChannel> channelCaptor = new AtomicReference<>();
|
||||
|
||||
long requestId = responseHandlers.add(new Transport.ResponseContext<>(new TransportResponseHandler<TestResponse>() {
|
||||
@Override
|
||||
public void handleResponse(TestResponse response) {
|
||||
responseCaptor.set(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleException(TransportException exp) {
|
||||
exceptionCaptor.set(exp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executor() {
|
||||
return ThreadPool.Names.SAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TestResponse read(StreamInput in) throws IOException {
|
||||
return new TestResponse(in);
|
||||
}
|
||||
}, null, action));
|
||||
RequestHandlerRegistry<TestRequest> registry = new RequestHandlerRegistry<>(
|
||||
action,
|
||||
TestRequest::new,
|
||||
taskManager,
|
||||
(request, channel, task) -> {
|
||||
channelCaptor.set(channel);
|
||||
requestCaptor.set(request);
|
||||
},
|
||||
ThreadPool.Names.SAME,
|
||||
false,
|
||||
true
|
||||
);
|
||||
requestHandlers.registerHandler(registry);
|
||||
String requestValue = randomAlphaOfLength(10);
|
||||
OutboundMessage.Request request = new OutboundMessage.Request(
|
||||
threadPool.getThreadContext(),
|
||||
new String[0],
|
||||
new TestRequest(requestValue),
|
||||
version,
|
||||
action,
|
||||
requestId,
|
||||
false,
|
||||
false
|
||||
);
|
||||
|
||||
BytesReference fullRequestBytes = request.serialize(new BytesStreamOutput());
|
||||
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
|
||||
Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version);
|
||||
InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {});
|
||||
requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, requestMessage);
|
||||
|
||||
TransportChannel transportChannel = channelCaptor.get();
|
||||
assertEquals(Version.CURRENT, transportChannel.getVersion());
|
||||
assertEquals("transport", transportChannel.getChannelType());
|
||||
assertEquals(requestValue, requestCaptor.get().value);
|
||||
|
||||
String responseValue = randomAlphaOfLength(10);
|
||||
byte responseStatus = TransportStatus.setResponse((byte) 0);
|
||||
transportChannel.sendResponse(new TestResponse(responseValue));
|
||||
|
||||
BytesReference fullResponseBytes = channel.getMessageCaptor().get();
|
||||
// Create the response payload by intentionally stripping 1 byte away
|
||||
BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1);
|
||||
Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version);
|
||||
InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {});
|
||||
responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput());
|
||||
handler.inboundMessage(channel, responseMessage);
|
||||
|
||||
assertThat(exceptionCaptor.get(), instanceOf(RemoteTransportException.class));
|
||||
assertThat(exceptionCaptor.get().getCause(), instanceOf(TransportSerializationException.class));
|
||||
assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler"));
|
||||
}
|
||||
|
||||
private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) {
|
||||
return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) {
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue