Move outbound message handling to OutboundHandler (#40336)

Currently there are some components of message serializer and sending
that still occur in TcpTransport. This commit makes it possible to
send a message without the TcpTransport by moving all of the remaining
application logic to the OutboundHandler. Additionally, it adds unit
tests to ensure that this logic works as expected.
This commit is contained in:
Tim Brooks 2019-03-22 13:58:30 -06:00
parent 13d4d73ce3
commit 3860ddd1a4
No known key found for this signature in database
GPG Key ID: C2AA3BB91A889E77
12 changed files with 334 additions and 185 deletions

View File

@ -111,7 +111,7 @@ public class Netty4TransportIT extends ESNetty4IntegTestCase {
}
@Override
protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException {
super.handleRequest(channel, request, messageLengthBytes);
channelProfileName = TransportSettings.DEFAULT_PROFILE;
}

View File

@ -113,7 +113,7 @@ public class NioTransportIT extends NioIntegTestCase {
}
@Override
protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException {
super.handleRequest(channel, request, messageLengthBytes);
channelProfileName = TransportSettings.DEFAULT_PROFILE;
}

View File

@ -106,9 +106,9 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
features = Collections.emptySet();
}
final String action = streamInput.readString();
message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput);
message = new Request(threadContext, remoteVersion, status, requestId, action, features, streamInput);
} else {
message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput);
message = new Response(threadContext, remoteVersion, status, requestId, streamInput);
}
success = true;
return message;
@ -138,12 +138,12 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
}
}
public static class RequestMessage extends InboundMessage {
public static class Request extends InboundMessage {
private final String actionName;
private final Set<String> features;
RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
Request(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
StreamInput streamInput) {
super(threadContext, version, status, requestId, streamInput);
this.actionName = actionName;
@ -159,9 +159,9 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
}
}
public static class ResponseMessage extends InboundMessage {
public static class Response extends InboundMessage {
ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
Response(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
super(threadContext, version, status, requestId, streamInput);
}
}

View File

@ -22,8 +22,10 @@ package org.elasticsearch.transport;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
@ -32,49 +34,100 @@ import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.metrics.MeanMetric;
import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.transport.NetworkExceptionHelper;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
import java.util.Set;
final class OutboundHandler {
private static final Logger logger = LogManager.getLogger(OutboundHandler.class);
private final MeanMetric transmittedBytesMetric = new MeanMetric();
private final String nodeName;
private final Version version;
private final String[] features;
private final ThreadPool threadPool;
private final BigArrays bigArrays;
private final TransportLogger transportLogger;
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) {
OutboundHandler(String nodeName, Version version, String[] features, ThreadPool threadPool, BigArrays bigArrays,
TransportLogger transportLogger) {
this.nodeName = nodeName;
this.version = version;
this.features = features;
this.threadPool = threadPool;
this.bigArrays = bigArrays;
this.transportLogger = transportLogger;
}
void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
SendContext sendContext = new SendContext(channel, () -> bytes, listener);
try {
internalSendMessage(channel, sendContext);
internalSend(channel, sendContext);
} catch (IOException e) {
// This should not happen as the bytes are already serialized
throw new AssertionError(e);
}
}
void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSendMessage(channel, sendContext);
/**
* Sends the request to the given channel. This method should be used to send {@link TransportRequest}
* objects back to the caller.
*/
void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
final TransportRequest request, final TransportRequestOptions options, final Version channelVersion,
final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action,
requestId, isHandshake, compressRequest);
ActionListener<Void> listener = ActionListener.wrap(() ->
messageListener.onRequestSent(node, requestId, action, request, options));
sendMessage(channel, message, listener);
}
/**
* sends a message to the given channel, using the given callbacks.
* Sends the response to the given channel. This method should be used to send {@link TransportResponse}
* objects back to the caller.
*
* @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses
*/
private void internalSendMessage(TcpChannel channel, SendContext sendContext) throws IOException {
void sendResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel,
final long requestId, final String action, final TransportResponse response,
final boolean compress, final boolean isHandshake) throws IOException {
Version version = Version.min(this.version, nodeVersion);
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
requestId, isHandshake, compress);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
sendMessage(channel, message, listener);
}
/**
* Sends back an error response to the caller via the given channel
*/
void sendErrorResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel, final long requestId,
final String action, final Exception error) throws IOException {
Version version = Version.min(this.version, nodeVersion);
TransportAddress address = new TransportAddress(channel.getLocalAddress());
RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId,
false, false);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
sendMessage(channel, message, listener);
}
private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSend(channel, sendContext);
}
private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
BytesReference reference = sendContext.get();
try {
@ -91,6 +144,14 @@ final class OutboundHandler {
return transmittedBytesMetric;
}
void setMessageListener(TransportMessageListener listener) {
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
} else {
throw new IllegalStateException("Cannot set message listener twice");
}
}
private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {
private final OutboundMessage message;

View File

@ -106,19 +106,15 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
private static final BytesReference EMPTY_BYTES_REFERENCE = new BytesArray(new byte[0]);
private final String[] features;
protected final Settings settings;
private final CircuitBreakerService circuitBreakerService;
private final Version version;
protected final ThreadPool threadPool;
protected final BigArrays bigArrays;
protected final PageCacheRecycler pageCacheRecycler;
protected final NetworkService networkService;
protected final Set<ProfileSettings> profileSettings;
private static final TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {};
private volatile TransportMessageListener messageListener = NOOP_LISTENER;
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
private final ConcurrentMap<String, BoundTransportAddress> profileBoundAddresses = newConcurrentMap();
private final Map<String, List<TcpServerChannel>> serverChannels = newConcurrentMap();
@ -137,34 +133,23 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
private final TransportKeepAlive keepAlive;
private final InboundMessage.Reader reader;
private final OutboundHandler outboundHandler;
private final String nodeName;
public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry,
NetworkService networkService) {
this.settings = settings;
this.profileSettings = getProfileSettings(settings);
this.version = version;
this.threadPool = threadPool;
this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS);
this.pageCacheRecycler = pageCacheRecycler;
this.circuitBreakerService = circuitBreakerService;
this.networkService = networkService;
this.transportLogger = new TransportLogger();
this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger);
this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true),
(v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true));
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext());
this.nodeName = Node.NODE_NAME_SETTING.get(settings);
String nodeName = Node.NODE_NAME_SETTING.get(settings);
final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings);
String[] features;
if (defaultFeatures == null) {
this.features = new String[0];
features = new String[0];
} else {
defaultFeatures.names().forEach(key -> {
if (Booleans.parseBoolean(defaultFeatures.get(key)) == false) {
@ -172,8 +157,18 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
}
});
// use a sorted set to present the features in a consistent order
this.features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]);
features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]);
}
this.outboundHandler = new OutboundHandler(nodeName, version, features, threadPool, bigArrays, transportLogger);
this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true),
(v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext());
}
@Override
@ -182,8 +177,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
@Override
public synchronized void setMessageListener(TransportMessageListener listener) {
if (messageListener == NOOP_LISTENER) {
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
outboundHandler.setMessageListener(listener);
} else {
throw new IllegalStateException("Cannot set message listener twice");
}
@ -267,7 +263,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
throw new NodeNotConnectedException(node, "connection already closed");
}
TcpChannel channel = channel(options.type());
sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress);
outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), compress, false);
}
}
@ -661,81 +657,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
*/
protected abstract void stopInternal();
private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
final TransportRequest request, TransportRequestOptions options, Version channelVersion,
boolean compressRequest) throws IOException, TransportException {
sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false);
}
private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
final TransportRequest request, TransportRequestOptions options, Version channelVersion,
boolean compressRequest, boolean isHandshake) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action,
requestId, isHandshake, compressRequest);
ActionListener<Void> listener = ActionListener.wrap(() ->
messageListener.onRequestSent(node, requestId, action, request, options));
outboundHandler.sendMessage(channel, message, listener);
}
/**
* Sends back an error response to the caller via the given channel
*
* @param nodeVersion the caller node version
* @param features the caller features
* @param channel the channel to send the response to
* @param error the error to return
* @param requestId the request ID this response replies to
* @param action the action this response replies to
*/
public void sendErrorResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final Exception error,
final long requestId,
final String action) throws IOException {
Version version = Version.min(this.version, nodeVersion);
TransportAddress address = new TransportAddress(channel.getLocalAddress());
RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId,
false, false);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
outboundHandler.sendMessage(channel, message, listener);
}
/**
* Sends the response to the given channel. This method should be used to send {@link TransportResponse} objects back to the caller.
*
* @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending back errors to the caller
*/
public void sendResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final TransportResponse response,
final long requestId,
final String action,
final boolean compress) throws IOException {
sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false);
}
private void sendResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final TransportResponse response,
final long requestId,
final String action,
boolean compress,
boolean isHandshake) throws IOException {
Version version = Version.min(this.version, nodeVersion);
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
requestId, isHandshake, compress);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
outboundHandler.sendMessage(channel, message, listener);
}
/**
* Handles inbound message that has been decoded.
*
@ -913,7 +834,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
message.getStoredContext().restore();
threadContext.putTransient("_remote_address", remoteAddress);
if (message.isRequest()) {
handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length());
handleRequest(channel, (InboundMessage.Request) message, reference.length());
} else {
final TransportResponseHandler<?> handler;
long requestId = message.getRequestId();
@ -999,7 +920,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
});
}
protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException {
protected void handleRequest(TcpChannel channel, InboundMessage.Request message, int messageLengthBytes) throws IOException {
final Set<String> features = message.getFeatures();
final String profileName = channel.getProfile();
final String action = message.getActionName();
@ -1021,8 +942,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
} else {
getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes);
}
transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features, profileName,
messageLengthBytes, message.isCompress());
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
circuitBreakerService, messageLengthBytes, message.isCompress());
final TransportRequest request = reg.newRequest(stream);
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
@ -1032,8 +953,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
} catch (Exception e) {
// the circuit breaker tripped
if (transportChannel == null) {
transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features,
profileName, 0, message.isCompress());
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
circuitBreakerService, 0, message.isCompress());
}
try {
transportChannel.sendResponse(e);

View File

@ -20,6 +20,8 @@
package org.elasticsearch.transport;
import org.elasticsearch.Version;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import java.io.IOException;
import java.util.Set;
@ -28,38 +30,38 @@ import java.util.concurrent.atomic.AtomicBoolean;
public final class TcpTransportChannel implements TransportChannel {
private final AtomicBoolean released = new AtomicBoolean();
private final TcpTransport transport;
private final Version version;
private final Set<String> features;
private final OutboundHandler outboundHandler;
private final TcpChannel channel;
private final String action;
private final long requestId;
private final String profileName;
private final Version version;
private final Set<String> features;
private final CircuitBreakerService breakerService;
private final long reservedBytes;
private final TcpChannel channel;
private final boolean compressResponse;
TcpTransportChannel(TcpTransport transport, TcpChannel channel, String action, long requestId, Version version, Set<String> features,
String profileName, long reservedBytes, boolean compressResponse) {
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
this.version = version;
this.features = features;
this.channel = channel;
this.transport = transport;
this.outboundHandler = outboundHandler;
this.action = action;
this.requestId = requestId;
this.profileName = profileName;
this.breakerService = breakerService;
this.reservedBytes = reservedBytes;
this.compressResponse = compressResponse;
}
@Override
public String getProfileName() {
return profileName;
return channel.getProfile();
}
@Override
public void sendResponse(TransportResponse response) throws IOException {
try {
transport.sendResponse(version, features, channel, response, requestId, action, compressResponse);
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false);
} finally {
release(false);
}
@ -68,7 +70,7 @@ public final class TcpTransportChannel implements TransportChannel {
@Override
public void sendResponse(Exception exception) throws IOException {
try {
transport.sendErrorResponse(version, features, channel, exception, requestId, action);
outboundHandler.sendErrorResponse(version, features, channel, requestId, action, exception);
} finally {
release(true);
}
@ -79,7 +81,7 @@ public final class TcpTransportChannel implements TransportChannel {
private void release(boolean isExceptionResponse) {
if (released.compareAndSet(false, true)) {
assert (releaseBy = new Exception()) != null; // easier to debug if it's already closed
transport.getInFlightRequestBreaker().addWithoutBreaking(-reservedBytes);
breakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS).addWithoutBreaking(-reservedBytes);
} else if (isExceptionResponse == false) {
// only fail if we are not sending an error - we might send the error triggered by the previous
// sendResponse call

View File

@ -30,7 +30,6 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
import java.io.Closeable;
import java.io.IOException;
import java.net.UnknownHostException;

View File

@ -22,6 +22,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
public interface TransportMessageListener {
TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {};
/**
* Called once a request is received
* @param requestId the internal request ID

View File

@ -63,7 +63,7 @@ public class InboundMessageTests extends ESTestCase {
InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
BytesReference sliced = reference.slice(6, reference.length() - 6);
InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced);
InboundMessage.Request inboundMessage = (InboundMessage.Request) reader.deserialize(sliced);
// Check that deserialize does not overwrite current thread context.
assertEquals("header_value2", threadContext.getHeader("header"));
inboundMessage.getStoredContext().restore();
@ -102,7 +102,7 @@ public class InboundMessageTests extends ESTestCase {
InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
BytesReference sliced = reference.slice(6, reference.length() - 6);
InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced);
InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced);
// Check that deserialize does not overwrite current thread context.
assertEquals("header_value2", threadContext.getHeader("header"));
inboundMessage.getStoredContext().restore();
@ -138,7 +138,7 @@ public class InboundMessageTests extends ESTestCase {
InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
BytesReference sliced = reference.slice(6, reference.length() - 6);
InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced);
InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced);
// Check that deserialize does not overwrite current thread context.
assertEquals("header_value2", threadContext.getHeader("header"));
inboundMessage.getStoredContext().restore();

View File

@ -19,14 +19,16 @@
package org.elasticsearch.transport;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
@ -38,24 +40,34 @@ import org.junit.Before;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.instanceOf;
public class OutboundHandlerTests extends ESTestCase {
private final String feature1 = "feature1";
private final String feature2 = "feature2";
private final TestThreadPool threadPool = new TestThreadPool(getClass().getName());
private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
private final TransportRequestOptions options = TransportRequestOptions.EMPTY;
private OutboundHandler handler;
private FakeTcpChannel fakeTcpChannel;
private FakeTcpChannel channel;
private DiscoveryNode node;
@Before
public void setUp() throws Exception {
super.setUp();
TransportLogger transportLogger = new TransportLogger();
fakeTcpChannel = new FakeTcpChannel(randomBoolean());
handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger);
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
TransportAddress transportAddress = buildNewFakeTransportAddress();
node = new DiscoveryNode("", transportAddress, Version.CURRENT);
String[] features = {feature1, feature2};
handler = new OutboundHandler("node", Version.CURRENT, features, threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger);
}
@After
@ -70,10 +82,10 @@ public class OutboundHandlerTests extends ESTestCase {
AtomicBoolean isSuccess = new AtomicBoolean(false);
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<Void> listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set);
handler.sendBytes(fakeTcpChannel, bytesArray, listener);
handler.sendBytes(channel, bytesArray, listener);
BytesReference reference = fakeTcpChannel.getMessageCaptor().get();
ActionListener<Void> sendListener = fakeTcpChannel.getListenerCaptor().get();
BytesReference reference = channel.getMessageCaptor().get();
ActionListener<Void> sendListener = channel.getListenerCaptor().get();
if (randomBoolean()) {
sendListener.onResponse(null);
assertTrue(isSuccess.get());
@ -88,55 +100,51 @@ public class OutboundHandlerTests extends ESTestCase {
assertEquals(bytesArray, reference);
}
public void testSendMessage() throws IOException {
OutboundMessage message;
public void testSendRequest() throws IOException {
ThreadContext threadContext = threadPool.getThreadContext();
Version version = Version.CURRENT;
String actionName = "handshake";
Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
String action = "handshake";
long requestId = randomLongBetween(0, 300);
boolean isHandshake = randomBoolean();
boolean compress = randomBoolean();
String value = "message";
threadContext.putHeader("header", "header_value");
Writeable writeable = new Message(value);
Request request = new Request(value);
boolean isRequest = randomBoolean();
if (isRequest) {
message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake,
compress);
} else {
message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress);
AtomicReference<DiscoveryNode> nodeRef = new AtomicReference<>();
AtomicLong requestIdRef = new AtomicLong();
AtomicReference<String> actionRef = new AtomicReference<>();
AtomicReference<TransportRequest> requestRef = new AtomicReference<>();
handler.setMessageListener(new TransportMessageListener() {
@Override
public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request,
TransportRequestOptions options) {
nodeRef.set(node);
requestIdRef.set(requestId);
actionRef.set(action);
requestRef.set(request);
}
});
handler.sendRequest(node, channel, requestId, action, request, options, version, compress, isHandshake);
AtomicBoolean isSuccess = new AtomicBoolean(false);
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<Void> listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set);
handler.sendMessage(fakeTcpChannel, message, listener);
BytesReference reference = fakeTcpChannel.getMessageCaptor().get();
ActionListener<Void> sendListener = fakeTcpChannel.getListenerCaptor().get();
BytesReference reference = channel.getMessageCaptor().get();
ActionListener<Void> sendListener = channel.getListenerCaptor().get();
if (randomBoolean()) {
sendListener.onResponse(null);
assertTrue(isSuccess.get());
assertNull(exception.get());
} else {
IOException e = new IOException("failed");
sendListener.onFailure(e);
assertFalse(isSuccess.get());
assertSame(e, exception.get());
sendListener.onFailure(new IOException("failed"));
}
assertEquals(node, nodeRef.get());
assertEquals(requestId, requestIdRef.get());
assertEquals(action, actionRef.get());
assertEquals(request, requestRef.get());
InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
assertEquals(version, inboundMessage.getVersion());
assertEquals(requestId, inboundMessage.getRequestId());
if (isRequest) {
assertTrue(inboundMessage.isRequest());
assertFalse(inboundMessage.isResponse());
} else {
assertTrue(inboundMessage.isResponse());
assertFalse(inboundMessage.isRequest());
}
if (isHandshake) {
assertTrue(inboundMessage.isHandshake());
} else {
@ -147,7 +155,10 @@ public class OutboundHandlerTests extends ESTestCase {
} else {
assertFalse(inboundMessage.isCompress());
}
Message readMessage = new Message();
InboundMessage.Request inboundRequest = (InboundMessage.Request) inboundMessage;
assertThat(inboundRequest.getFeatures(), contains(feature1, feature2));
Request readMessage = new Request();
readMessage.readFrom(inboundMessage.getStreamInput());
assertEquals(value, readMessage.value);
@ -160,14 +171,163 @@ public class OutboundHandlerTests extends ESTestCase {
}
}
private static final class Message extends TransportMessage {
public void testSendResponse() throws IOException {
ThreadContext threadContext = threadPool.getThreadContext();
Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
String action = "handshake";
long requestId = randomLongBetween(0, 300);
boolean isHandshake = randomBoolean();
boolean compress = randomBoolean();
String value = "message";
threadContext.putHeader("header", "header_value");
Response response = new Response(value);
AtomicLong requestIdRef = new AtomicLong();
AtomicReference<String> actionRef = new AtomicReference<>();
AtomicReference<TransportResponse> responseRef = new AtomicReference<>();
handler.setMessageListener(new TransportMessageListener() {
@Override
public void onResponseSent(long requestId, String action, TransportResponse response) {
requestIdRef.set(requestId);
actionRef.set(action);
responseRef.set(response);
}
});
handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake);
BytesReference reference = channel.getMessageCaptor().get();
ActionListener<Void> sendListener = channel.getListenerCaptor().get();
if (randomBoolean()) {
sendListener.onResponse(null);
} else {
sendListener.onFailure(new IOException("failed"));
}
assertEquals(requestId, requestIdRef.get());
assertEquals(action, actionRef.get());
assertEquals(response, responseRef.get());
InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
assertEquals(version, inboundMessage.getVersion());
assertEquals(requestId, inboundMessage.getRequestId());
assertFalse(inboundMessage.isRequest());
assertTrue(inboundMessage.isResponse());
if (isHandshake) {
assertTrue(inboundMessage.isHandshake());
} else {
assertFalse(inboundMessage.isHandshake());
}
if (compress) {
assertTrue(inboundMessage.isCompress());
} else {
assertFalse(inboundMessage.isCompress());
}
InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage;
assertFalse(inboundResponse.isError());
Response readMessage = new Response();
readMessage.readFrom(inboundMessage.getStreamInput());
assertEquals(value, readMessage.value);
try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext();
assertNull(threadContext.getHeader("header"));
storedContext.restore();
assertEquals("header_value", threadContext.getHeader("header"));
}
}
}
public void testErrorResponse() throws IOException {
ThreadContext threadContext = threadPool.getThreadContext();
Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
String action = "handshake";
long requestId = randomLongBetween(0, 300);
threadContext.putHeader("header", "header_value");
ElasticsearchException error = new ElasticsearchException("boom");
AtomicLong requestIdRef = new AtomicLong();
AtomicReference<String> actionRef = new AtomicReference<>();
AtomicReference<Exception> responseRef = new AtomicReference<>();
handler.setMessageListener(new TransportMessageListener() {
@Override
public void onResponseSent(long requestId, String action, Exception error) {
requestIdRef.set(requestId);
actionRef.set(action);
responseRef.set(error);
}
});
handler.sendErrorResponse(version, Collections.emptySet(), channel, requestId, action, error);
BytesReference reference = channel.getMessageCaptor().get();
ActionListener<Void> sendListener = channel.getListenerCaptor().get();
if (randomBoolean()) {
sendListener.onResponse(null);
} else {
sendListener.onFailure(new IOException("failed"));
}
assertEquals(requestId, requestIdRef.get());
assertEquals(action, actionRef.get());
assertEquals(error, responseRef.get());
InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
assertEquals(version, inboundMessage.getVersion());
assertEquals(requestId, inboundMessage.getRequestId());
assertFalse(inboundMessage.isRequest());
assertTrue(inboundMessage.isResponse());
assertFalse(inboundMessage.isCompress());
assertFalse(inboundMessage.isHandshake());
InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage;
assertTrue(inboundResponse.isError());
RemoteTransportException remoteException = inboundMessage.getStreamInput().readException();
assertThat(remoteException.getCause(), instanceOf(ElasticsearchException.class));
assertEquals(remoteException.getCause().getMessage(), "boom");
assertEquals(action, remoteException.action());
assertEquals(channel.getLocalAddress(), remoteException.address().address());
try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext();
assertNull(threadContext.getHeader("header"));
storedContext.restore();
assertEquals("header_value", threadContext.getHeader("header"));
}
}
}
private static final class Request extends TransportRequest {
public String value;
private Message() {
private Request() {
}
private Message(String value) {
private Request(String value) {
this.value = value;
}
@Override
public void readFrom(StreamInput in) throws IOException {
value = in.readString();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(value);
}
}
private static final class Response extends TransportResponse {
public String value;
private Response() {
}
private Response(String value) {
this.value = value;
}

View File

@ -2008,12 +2008,12 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry,
new NoneCircuitBreakerService()) {
@Override
protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes)
protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes)
throws IOException {
// we flip the isHandshake bit back and act like the handler is not found
byte status = (byte) (request.status & ~(1 << 3));
Version version = request.getVersion();
InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version,
InboundMessage.Request nonHandshakeRequest = new InboundMessage.Request(request.threadContext, version,
status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput());
super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes);
}

View File

@ -44,6 +44,10 @@ public class FakeTcpChannel implements TcpChannel {
this(isServer, "profile", new AtomicReference<>());
}
public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress) {
this(isServer, localAddress, remoteAddress, "profile", new AtomicReference<>());
}
public FakeTcpChannel(boolean isServer, AtomicReference<BytesReference> messageCaptor) {
this(isServer, "profile", messageCaptor);
}