From ba06c14a9778b38568e2070ae1e5b3d236a11a4e Mon Sep 17 00:00:00 2001 From: Boaz Leskes Date: Tue, 7 Feb 2017 22:11:32 +0200 Subject: [PATCH] TransportService.connectToNode should validate remote node ID (#22828) #22194 gave us the ability to open low level temporary connections to remote node based on their address. With this use case out of the way, actual full blown connections should validate the node on the other side, making sure we speak to who we think we speak to. This helps in case where multiple nodes are started on the same host and a quick node restart causes them to swap addresses, which in turn can cause confusion down the road. --- .../TransportClientNodesService.java | 191 +++++++++--------- .../cluster/node/DiscoveryNode.java | 26 ++- .../common/CheckedBiConsumer.java | 30 +++ .../discovery/zen/ZenDiscovery.java | 47 +++-- .../transport/ConnectionProfile.java | 12 ++ .../elasticsearch/transport/TcpTransport.java | 35 +++- .../elasticsearch/transport/Transport.java | 16 +- .../transport/TransportService.java | 23 ++- .../node/tasks/CancellableTasksTests.java | 18 +- .../node/tasks/TaskManagerTestCase.java | 28 ++- .../node/tasks/TransportTasksActionTests.java | 14 +- .../transport/FailAndRetryMockTransport.java | 6 +- .../TransportClientHeadersTests.java | 20 +- .../cluster/NodeConnectionsServiceTests.java | 5 +- .../discovery/ZenFaultDetectionTests.java | 68 ++++--- .../zen/PublishClusterStateActionTests.java | 7 +- .../discovery/zen/UnicastZenPingTests.java | 6 +- .../transport/ConnectionProfileTests.java | 15 ++ .../transport/TCPTransportTests.java | 35 ++++ .../transport/TransportActionProxyTests.java | 9 +- .../TransportServiceHandshakeTests.java | 18 ++ .../netty4/Netty4ScheduledPingTests.java | 9 +- .../test/ClusterServiceUtils.java | 10 +- .../test/transport/CapturingTransport.java | 5 +- .../test/transport/MockTransportService.java | 47 +++-- .../AbstractSimpleTransportTestCase.java | 171 ++++++++++++---- .../transport/MockTcpTransport.java | 3 +- .../transport/MockTcpTransportTests.java | 4 +- 28 files changed, 594 insertions(+), 284 deletions(-) create mode 100644 core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java diff --git a/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java b/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java index ea2906dab67..dbcf0edef28 100644 --- a/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java +++ b/core/src/main/java/org/elasticsearch/client/transport/TransportClientNodesService.java @@ -22,6 +22,7 @@ package org.elasticsearch.client.transport; import com.carrotsearch.hppc.cursors.ObjectCursor; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; +import org.apache.lucene.util.IOUtils; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; @@ -38,6 +39,7 @@ import org.elasticsearch.common.component.AbstractComponent; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.threadpool.ThreadPool; @@ -46,6 +48,8 @@ import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.FutureTransportResponseHandler; import org.elasticsearch.transport.NodeDisconnectedException; import org.elasticsearch.transport.NodeNotConnectedException; +import org.elasticsearch.transport.PlainTransportFuture; +import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponseHandler; @@ -401,51 +405,37 @@ final class TransportClientNodesService extends AbstractComponent implements Clo HashSet newNodes = new HashSet<>(); HashSet newFilteredNodes = new HashSet<>(); for (DiscoveryNode listedNode : listedNodes) { - if (!transportService.nodeConnected(listedNode)) { - try { - // its a listed node, light connect to it... - logger.trace("connecting to listed node [{}]", listedNode); - transportService.connectToNode(listedNode, LISTED_NODES_PROFILE); - } catch (Exception e) { - logger.info( - (Supplier) - () -> new ParameterizedMessage("failed to connect to node [{}], removed from nodes list", listedNode), e); - hostFailureListener.onNodeDisconnected(listedNode, e); - newFilteredNodes.add(listedNode); - continue; - } - } - try { - LivenessResponse livenessResponse = transportService.submitRequest(listedNode, TransportLivenessAction.NAME, - new LivenessRequest(), - TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE).withTimeout(pingTimeout).build(), - new FutureTransportResponseHandler() { - @Override - public LivenessResponse newInstance() { - return new LivenessResponse(); - } - }).txGet(); + try (Transport.Connection connection = transportService.openConnection(listedNode, LISTED_NODES_PROFILE)){ + final PlainTransportFuture handler = new PlainTransportFuture<>( + new FutureTransportResponseHandler() { + @Override + public LivenessResponse newInstance() { + return new LivenessResponse(); + } + }); + transportService.sendRequest(connection, TransportLivenessAction.NAME, new LivenessRequest(), + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE).withTimeout(pingTimeout).build(), + handler); + final LivenessResponse livenessResponse = handler.txGet(); if (!ignoreClusterName && !clusterName.equals(livenessResponse.getClusterName())) { logger.warn("node {} not part of the cluster {}, ignoring...", listedNode, clusterName); newFilteredNodes.add(listedNode); - } else if (livenessResponse.getDiscoveryNode() != null) { + } else { // use discovered information but do keep the original transport address, // so people can control which address is exactly used. DiscoveryNode nodeWithInfo = livenessResponse.getDiscoveryNode(); newNodes.add(new DiscoveryNode(nodeWithInfo.getName(), nodeWithInfo.getId(), nodeWithInfo.getEphemeralId(), nodeWithInfo.getHostName(), nodeWithInfo.getHostAddress(), listedNode.getAddress(), nodeWithInfo.getAttributes(), nodeWithInfo.getRoles(), nodeWithInfo.getVersion())); - } else { - // although we asked for one node, our target may not have completed - // initialization yet and doesn't have cluster nodes - logger.debug("node {} didn't return any discovery info, temporarily using transport discovery node", listedNode); - newNodes.add(listedNode); } + } catch (ConnectTransportException e) { + logger.debug( + (Supplier) + () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", listedNode), e); + hostFailureListener.onNodeDisconnected(listedNode, e); } catch (Exception e) { logger.info( (Supplier) () -> new ParameterizedMessage("failed to get node info for {}, disconnecting...", listedNode), e); - transportService.disconnectFromNode(listedNode); - hostFailureListener.onNodeDisconnected(listedNode, e); } } @@ -470,78 +460,91 @@ final class TransportClientNodesService extends AbstractComponent implements Clo final CountDownLatch latch = new CountDownLatch(nodesToPing.size()); final ConcurrentMap clusterStateResponses = ConcurrentCollections.newConcurrentMap(); - for (final DiscoveryNode listedNode : nodesToPing) { - threadPool.executor(ThreadPool.Names.MANAGEMENT).execute(new Runnable() { - @Override - public void run() { - try { - if (!transportService.nodeConnected(listedNode)) { - try { + try { + for (final DiscoveryNode nodeToPing : nodesToPing) { + threadPool.executor(ThreadPool.Names.MANAGEMENT).execute(new AbstractRunnable() { - // if its one of the actual nodes we will talk to, not to listed nodes, fully connect - if (nodes.contains(listedNode)) { - logger.trace("connecting to cluster node [{}]", listedNode); - transportService.connectToNode(listedNode); - } else { - // its a listed node, light connect to it... - logger.trace("connecting to listed node (light) [{}]", listedNode); - transportService.connectToNode(listedNode, LISTED_NODES_PROFILE); - } - } catch (Exception e) { - logger.debug( - (Supplier) - () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", listedNode), e); - latch.countDown(); - return; + /** + * we try to reuse existing connections but if needed we will open a temporary connection + * that will be closed at the end of the execution. + */ + Transport.Connection connectionToClose = null; + + @Override + public void onAfter() { + IOUtils.closeWhileHandlingException(connectionToClose); + } + + @Override + public void onFailure(Exception e) { + latch.countDown(); + if (e instanceof ConnectTransportException) { + logger.debug((Supplier) + () -> new ParameterizedMessage("failed to connect to node [{}], ignoring...", nodeToPing), e); + hostFailureListener.onNodeDisconnected(nodeToPing, e); + } else { + logger.info( + (Supplier) () -> new ParameterizedMessage( + "failed to get local cluster state info for {}, disconnecting...", nodeToPing), e); + } + } + + @Override + protected void doRun() throws Exception { + Transport.Connection pingConnection = null; + if (nodes.contains(nodeToPing)) { + try { + pingConnection = transportService.getConnection(nodeToPing); + } catch (NodeNotConnectedException e) { + // will use a temp connection } } - transportService.sendRequest(listedNode, ClusterStateAction.NAME, - Requests.clusterStateRequest().clear().nodes(true).local(true), - TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE) - .withTimeout(pingTimeout).build(), - new TransportResponseHandler() { + if (pingConnection == null) { + logger.trace("connecting to cluster node [{}]", nodeToPing); + connectionToClose = transportService.openConnection(nodeToPing, LISTED_NODES_PROFILE); + pingConnection = connectionToClose; + } + transportService.sendRequest(pingConnection, ClusterStateAction.NAME, + Requests.clusterStateRequest().clear().nodes(true).local(true), + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STATE) + .withTimeout(pingTimeout).build(), + new TransportResponseHandler() { - @Override - public ClusterStateResponse newInstance() { - return new ClusterStateResponse(); + @Override + public ClusterStateResponse newInstance() { + return new ClusterStateResponse(); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public void handleResponse(ClusterStateResponse response) { + clusterStateResponses.put(nodeToPing, response); + latch.countDown(); + } + + @Override + public void handleException(TransportException e) { + logger.info( + (Supplier) () -> new ParameterizedMessage( + "failed to get local cluster state for {}, disconnecting...", nodeToPing), e); + try { + hostFailureListener.onNodeDisconnected(nodeToPing, e); } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - - @Override - public void handleResponse(ClusterStateResponse response) { - clusterStateResponses.put(listedNode, response); + finally { latch.countDown(); } - - @Override - public void handleException(TransportException e) { - logger.info( - (Supplier) () -> new ParameterizedMessage( - "failed to get local cluster state for {}, disconnecting...", listedNode), e); - transportService.disconnectFromNode(listedNode); - latch.countDown(); - hostFailureListener.onNodeDisconnected(listedNode, e); - } - }); - } catch (Exception e) { - logger.info( - (Supplier)() -> new ParameterizedMessage( - "failed to get local cluster state info for {}, disconnecting...", listedNode), e); - transportService.disconnectFromNode(listedNode); - latch.countDown(); - hostFailureListener.onNodeDisconnected(listedNode, e); + } + }); } - } - }); - } - - try { + }); + } latch.await(); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); return; } diff --git a/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java b/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java index c81161f1deb..3eea37e2c89 100644 --- a/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java +++ b/core/src/main/java/org/elasticsearch/cluster/node/DiscoveryNode.java @@ -191,20 +191,26 @@ public class DiscoveryNode implements Writeable, ToXContent { /** Creates a DiscoveryNode representing the local node. */ public static DiscoveryNode createLocal(Settings settings, TransportAddress publishAddress, String nodeId) { Map attributes = new HashMap<>(Node.NODE_ATTRIBUTES.get(settings).getAsMap()); - Set roles = new HashSet<>(); - if (Node.NODE_INGEST_SETTING.get(settings)) { - roles.add(DiscoveryNode.Role.INGEST); - } - if (Node.NODE_MASTER_SETTING.get(settings)) { - roles.add(DiscoveryNode.Role.MASTER); - } - if (Node.NODE_DATA_SETTING.get(settings)) { - roles.add(DiscoveryNode.Role.DATA); - } + Set roles = getRolesFromSettings(settings); return new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), nodeId, publishAddress, attributes, roles, Version.CURRENT); } + /** extract node roles from the given settings */ + public static Set getRolesFromSettings(Settings settings) { + Set roles = new HashSet<>(); + if (Node.NODE_INGEST_SETTING.get(settings)) { + roles.add(Role.INGEST); + } + if (Node.NODE_MASTER_SETTING.get(settings)) { + roles.add(Role.MASTER); + } + if (Node.NODE_DATA_SETTING.get(settings)) { + roles.add(Role.DATA); + } + return roles; + } + /** * Creates a new {@link DiscoveryNode} by reading from the stream provided as argument * @param in the stream diff --git a/core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java b/core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java new file mode 100644 index 00000000000..3f8b76bf365 --- /dev/null +++ b/core/src/main/java/org/elasticsearch/common/CheckedBiConsumer.java @@ -0,0 +1,30 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common; + +import java.util.function.BiConsumer; + +/** + * A {@link BiConsumer}-like interface which allows throwing checked exceptions. + */ +@FunctionalInterface +public interface CheckedBiConsumer { + void accept(T t, U u) throws E; +} diff --git a/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java b/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java index effc92a0c67..b6a023bad35 100644 --- a/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java +++ b/core/src/main/java/org/elasticsearch/discovery/zen/ZenDiscovery.java @@ -27,9 +27,9 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterStateTaskExecutor; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateTaskConfig; +import org.elasticsearch.cluster.ClusterStateTaskExecutor; import org.elasticsearch.cluster.ClusterStateTaskListener; import org.elasticsearch.cluster.LocalClusterUpdateTask; import org.elasticsearch.cluster.NotMasterException; @@ -51,6 +51,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.discovery.DiscoveryStats; @@ -113,6 +114,7 @@ public class ZenDiscovery extends AbstractLifecycleComponent implements Discover private final NodesFaultDetection nodesFD; private final PublishClusterStateAction publishClusterState; private final MembershipAction membership; + private final ThreadPool threadPool; private final TimeValue pingTimeout; private final TimeValue joinTimeout; @@ -156,6 +158,7 @@ public class ZenDiscovery extends AbstractLifecycleComponent implements Discover this.joinRetryDelay = JOIN_RETRY_DELAY_SETTING.get(settings); this.maxPingsFromAnotherMaster = MAX_PINGS_FROM_ANOTHER_MASTER_SETTING.get(settings); this.sendLeaveRequest = SEND_LEAVE_REQUEST_SETTING.get(settings); + this.threadPool = threadPool; this.masterElectionIgnoreNonMasters = MASTER_ELECTION_IGNORE_NON_MASTER_PINGS_SETTING.get(settings); this.masterElectionWaitForJoinsTimeout = MASTER_ELECTION_WAIT_FOR_JOINS_TIMEOUT_SETTING.get(settings); @@ -189,7 +192,7 @@ public class ZenDiscovery extends AbstractLifecycleComponent implements Discover discoverySettings, clusterService.getClusterName()); this.membership = new MembershipAction(settings, transportService, this::localNode, new MembershipListener()); - this.joinThreadControl = new JoinThreadControl(threadPool); + this.joinThreadControl = new JoinThreadControl(); transportService.registerRequestHandler( DISCOVERY_REJOIN_ACTION_NAME, RejoinClusterRequest::new, ThreadPool.Names.SAME, new RejoinClusterRequestHandler()); @@ -968,21 +971,28 @@ public class ZenDiscovery extends AbstractLifecycleComponent implements Discover return rejoin(localClusterState, "zen-disco-discovered another master with a new cluster_state [" + otherMaster + "][" + reason + "]"); } else { logger.warn("discovered [{}] which is also master but with an older cluster_state, telling [{}] to rejoin the cluster ([{}])", otherMaster, otherMaster, reason); - try { - // make sure we're connected to this node (connect to node does nothing if we're already connected) - // since the network connections are asymmetric, it may be that we received a state but have disconnected from the node - // in the past (after a master failure, for example) - transportService.connectToNode(otherMaster); - transportService.sendRequest(otherMaster, DISCOVERY_REJOIN_ACTION_NAME, new RejoinClusterRequest(localClusterState.nodes().getLocalNodeId()), new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { + // spawn to a background thread to not do blocking operations on the cluster state thread + threadPool.generic().execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), e); + } - @Override - public void handleException(TransportException exp) { - logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), exp); - } - }); - } catch (Exception e) { - logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), e); - } + @Override + protected void doRun() throws Exception { + // make sure we're connected to this node (connect to node does nothing if we're already connected) + // since the network connections are asymmetric, it may be that we received a state but have disconnected from the node + // in the past (after a master failure, for example) + transportService.connectToNode(otherMaster); + transportService.sendRequest(otherMaster, DISCOVERY_REJOIN_ACTION_NAME, new RejoinClusterRequest(localNode().getId()), new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { + + @Override + public void handleException(TransportException exp) { + logger.warn((Supplier) () -> new ParameterizedMessage("failed to send rejoin request to [{}]", otherMaster), exp); + } + }); + } + }); return LocalClusterUpdateTask.unchanged(); } } @@ -1132,14 +1142,9 @@ public class ZenDiscovery extends AbstractLifecycleComponent implements Discover */ private class JoinThreadControl { - private final ThreadPool threadPool; private final AtomicBoolean running = new AtomicBoolean(false); private final AtomicReference currentJoinThread = new AtomicReference<>(); - JoinThreadControl(ThreadPool threadPool) { - this.threadPool = threadPool; - } - /** returns true if join thread control is started and there is currently an active join thread */ public boolean joinThreadActive() { Thread currentThread = currentJoinThread.get(); diff --git a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java index 2dc605cf3d4..17f3f7b7b4a 100644 --- a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java +++ b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java @@ -79,6 +79,18 @@ public final class ConnectionProfile { private TimeValue connectTimeout; private TimeValue handshakeTimeout; + /** create an empty builder */ + public Builder() { + } + + /** copy constructor, using another profile as a base */ + public Builder(ConnectionProfile source) { + handles.addAll(source.getHandles()); + offset = source.getNumConnections(); + handles.forEach(th -> addedTypes.addAll(th.types)); + connectTimeout = source.getConnectTimeout(); + handshakeTimeout = source.getHandshakeTimeout(); + } /** * Sets a connect timeout for this connection profile */ diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 5cfb5f7a423..79b4ff0f9f7 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -27,6 +27,8 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; @@ -198,6 +200,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i int connectionsPerNodePing = CONNECTIONS_PER_NODE_PING.get(settings); ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); builder.setConnectTimeout(TCP_CONNECT_TIMEOUT.get(settings)); + builder.setHandshakeTimeout(TCP_CONNECT_TIMEOUT.get(settings)); builder.addConnections(connectionsPerNodeBulk, TransportRequestOptions.Type.BULK); builder.addConnections(connectionsPerNodePing, TransportRequestOptions.Type.PING); // if we are not master eligible we don't need a dedicated channel to publish the state @@ -422,8 +425,10 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) { - connectionProfile = connectionProfile == null ? defaultConnectionProfile : connectionProfile; + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); if (node == null) { throw new ConnectTransportException(null, "can't connect to a null node"); } @@ -438,10 +443,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i try { try { nodeChannels = openConnection(node, connectionProfile); + connectionValidator.accept(nodeChannels, connectionProfile); } catch (Exception e) { logger.trace( (Supplier) () -> new ParameterizedMessage( "failed to connect to [{}], cleaning dangling connections", node), e); + IOUtils.closeWhileHandlingException(nodeChannels); throw e; } // we acquire a connection lock, so no way there is an existing connection @@ -461,6 +468,29 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } } + /** + * takes a {@link ConnectionProfile} that have been passed as a parameter to the public methods + * and resolves it to a fully specified (i.e., no nulls) profile + */ + static ConnectionProfile resolveConnectionProfile(@Nullable ConnectionProfile connectionProfile, + ConnectionProfile defaultConnectionProfile) { + Objects.requireNonNull(defaultConnectionProfile); + if (connectionProfile == null) { + return defaultConnectionProfile; + } else if (connectionProfile.getConnectTimeout() != null && connectionProfile.getHandshakeTimeout() != null) { + return connectionProfile; + } else { + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(connectionProfile); + if (connectionProfile.getConnectTimeout() == null) { + builder.setConnectTimeout(defaultConnectionProfile.getConnectTimeout()); + } + if (connectionProfile.getHandshakeTimeout() == null) { + builder.setHandshakeTimeout(defaultConnectionProfile.getHandshakeTimeout()); + } + return builder.build(); + } + } + @Override public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException { if (node == null) { @@ -468,6 +498,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i } boolean success = false; NodeChannels nodeChannels = null; + connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); globalLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); diff --git a/core/src/main/java/org/elasticsearch/transport/Transport.java b/core/src/main/java/org/elasticsearch/transport/Transport.java index 44c72e1f548..350251b807a 100644 --- a/core/src/main/java/org/elasticsearch/transport/Transport.java +++ b/core/src/main/java/org/elasticsearch/transport/Transport.java @@ -21,6 +21,7 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.component.LifecycleComponent; @@ -63,9 +64,11 @@ public interface Transport extends LifecycleComponent { boolean nodeConnected(DiscoveryNode node); /** - * Connects to a node with the given connection profile. If the node is already connected this method has no effect + * Connects to a node with the given connection profile. If the node is already connected this method has no effect. + * Once a successful is established, it can be validated before being exposed. */ - void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException; + void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) throws ConnectTransportException; /** * Disconnected from the given node, if not connected, will do nothing. @@ -94,15 +97,16 @@ public interface Transport extends LifecycleComponent { * implementation. * * @throws NodeNotConnectedException if the node is not connected - * @see #connectToNode(DiscoveryNode, ConnectionProfile) + * @see #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer) */ Connection getConnection(DiscoveryNode node); /** - * Opens a new connection to the given node and returns it. In contrast to {@link #connectToNode(DiscoveryNode, ConnectionProfile)} - * the returned connection is not managed by the transport implementation. This connection must be closed once it's not needed anymore. + * Opens a new connection to the given node and returns it. In contrast to + * {@link #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer)} the returned connection is not managed by + * the transport implementation. This connection must be closed once it's not needed anymore. * This connection type can be used to execute a handshake between two nodes before the node will be published via - * {@link #connectToNode(DiscoveryNode, ConnectionProfile)}. + * {@link #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer)}. */ Connection openConnection(DiscoveryNode node, ConnectionProfile profile) throws IOException; diff --git a/core/src/main/java/org/elasticsearch/transport/TransportService.java b/core/src/main/java/org/elasticsearch/transport/TransportService.java index 5e26d0a4b37..a974a932d42 100644 --- a/core/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/core/src/main/java/org/elasticsearch/transport/TransportService.java @@ -73,7 +73,7 @@ import static org.elasticsearch.common.settings.Setting.listSetting; public class TransportService extends AbstractLifecycleComponent { public static final String DIRECT_RESPONSE_PROFILE = ".direct"; - private static final String HANDSHAKE_ACTION_NAME = "internal:transport/handshake"; + public static final String HANDSHAKE_ACTION_NAME = "internal:transport/handshake"; private final CountDownLatch blockIncomingRequestsLatch = new CountDownLatch(1); protected final Transport transport; @@ -130,7 +130,7 @@ public class TransportService extends AbstractLifecycleComponent { @Override public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) throws IOException, TransportException { - sendLocalRequest(requestId, action, request); + sendLocalRequest(requestId, action, request, options); } @Override @@ -206,6 +206,7 @@ public class TransportService extends AbstractLifecycleComponent { HANDSHAKE_ACTION_NAME, () -> HandshakeRequest.INSTANCE, ThreadPool.Names.SAME, + false, false, (request, channel) -> channel.sendResponse( new HandshakeResponse(localNode, clusterName, localNode.getVersion()))); } @@ -307,7 +308,13 @@ public class TransportService extends AbstractLifecycleComponent { if (isLocalNode(node)) { return; } - transport.connectToNode(node, connectionProfile); + transport.connectToNode(node, connectionProfile, (newConnection, actualProfile) -> { + // We don't validate cluster names to allow for tribe node connections. + final DiscoveryNode remote = handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true); + if (node.equals(remote) == false) { + throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote); + } + }); } /** @@ -393,7 +400,7 @@ public class TransportService extends AbstractLifecycleComponent { } - static class HandshakeResponse extends TransportResponse { + public static class HandshakeResponse extends TransportResponse { private DiscoveryNode discoveryNode; private ClusterName clusterName; private Version version; @@ -401,7 +408,7 @@ public class TransportService extends AbstractLifecycleComponent { HandshakeResponse() { } - HandshakeResponse(DiscoveryNode discoveryNode, ClusterName clusterName, Version version) { + public HandshakeResponse(DiscoveryNode discoveryNode, ClusterName clusterName, Version version) { this.discoveryNode = discoveryNode; this.version = version; this.clusterName = clusterName; @@ -595,9 +602,11 @@ public class TransportService extends AbstractLifecycleComponent { } } - private void sendLocalRequest(long requestId, final String action, final TransportRequest request) { + private void sendLocalRequest(long requestId, final String action, final TransportRequest request, TransportRequestOptions options) { final DirectResponseChannel channel = new DirectResponseChannel(logger, localNode, action, requestId, adapter, threadPool); try { + adapter.onRequestSent(localNode, requestId, action, request, options); + adapter.onRequestReceived(requestId, action); final RequestHandlerRegistry reg = adapter.getRequestHandler(action); if (reg == null) { throw new ActionNotFoundTransportException("Action [" + action + "] not found"); @@ -1076,6 +1085,7 @@ public class TransportService extends AbstractLifecycleComponent { @Override public void sendResponse(final TransportResponse response, TransportResponseOptions options) throws IOException { + adapter.onResponseSent(requestId, action, response, options); final TransportResponseHandler handler = adapter.onResponseReceived(requestId); // ignore if its null, the adapter logs it if (handler != null) { @@ -1099,6 +1109,7 @@ public class TransportService extends AbstractLifecycleComponent { @Override public void sendResponse(Exception exception) throws IOException { + adapter.onResponseSent(requestId, action, exception); final TransportResponseHandler handler = adapter.onResponseReceived(requestId); // ignore if its null, the adapter logs it if (handler != null) { diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java index decff2ffc37..c28fddf68ad 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java @@ -212,7 +212,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { CancellableTestNodesAction[] actions = new CancellableTestNodesAction[nodesCount]; for (int i = 0; i < testNodes.length; i++) { boolean shouldBlock = blockOnNodes.contains(testNodes[i]); - logger.info("The action in the node [{}] should block: [{}]", testNodes[i].discoveryNode.getId(), shouldBlock); + logger.info("The action in the node [{}] should block: [{}]", testNodes[i].getNodeId(), shouldBlock); actions[i] = new CancellableTestNodesAction(CLUSTER_SETTINGS, "testAction", threadPool, testNodes[i] .clusterService, testNodes[i].transportService, shouldBlock, actionLatch); } @@ -251,7 +251,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { // Cancel main task CancelTasksRequest request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setTaskId(new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId())); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), mainTask.getId())); // And send the cancellation request to a random node CancelTasksResponse response = testNodes[randomIntBetween(0, testNodes.length - 1)].transportCancelTasksAction.execute(request) .get(); @@ -288,7 +288,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { // Make sure that tasks are no longer running ListTasksResponse listTasksResponse = testNodes[randomIntBetween(0, testNodes.length - 1)] .transportListTasksAction.execute(new ListTasksRequest().setTaskId( - new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId()))).get(); + new TaskId(testNodes[0].getNodeId(), mainTask.getId()))).get(); assertEquals(0, listTasksResponse.getTasks().size()); // Make sure that there are no leftover bans, the ban removal is async, so we might return from the cancellation @@ -323,7 +323,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { // Cancel all child tasks without cancelling the main task, which should quit on its own CancelTasksRequest request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setParentTaskId(new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId())); + request.setParentTaskId(new TaskId(testNodes[0].getNodeId(), mainTask.getId())); // And send the cancellation request to a random node CancelTasksResponse response = testNodes[randomIntBetween(1, testNodes.length - 1)].transportCancelTasksAction.execute(request) .get(); @@ -339,7 +339,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { // Make sure that main task is no longer running ListTasksResponse listTasksResponse = testNodes[randomIntBetween(0, testNodes.length - 1)] .transportListTasksAction.execute(new ListTasksRequest().setTaskId( - new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId()))).get(); + new TaskId(testNodes[0].getNodeId(), mainTask.getId()))).get(); assertEquals(0, listTasksResponse.getTasks().size()); } catch (ExecutionException | InterruptedException ex) { @@ -374,7 +374,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { } }); - String mainNode = testNodes[0].discoveryNode.getId(); + String mainNode = testNodes[0].getNodeId(); // Make sure that tasks are running ListTasksResponse listTasksResponse = testNodes[randomIntBetween(0, testNodes.length - 1)] @@ -384,12 +384,12 @@ public class CancellableTasksTests extends TaskManagerTestCase { // Simulate the coordinating node leaving the cluster DiscoveryNode[] discoveryNodes = new DiscoveryNode[testNodes.length - 1]; for (int i = 1; i < testNodes.length; i++) { - discoveryNodes[i - 1] = testNodes[i].discoveryNode; + discoveryNodes[i - 1] = testNodes[i].discoveryNode(); } DiscoveryNode master = discoveryNodes[0]; for (int i = 1; i < testNodes.length; i++) { // Notify only nodes that should remain in the cluster - setState(testNodes[i].clusterService, ClusterStateCreationUtils.state(testNodes[i].discoveryNode, master, discoveryNodes)); + setState(testNodes[i].clusterService, ClusterStateCreationUtils.state(testNodes[i].discoveryNode(), master, discoveryNodes)); } if (simulateBanBeforeLeaving) { @@ -397,7 +397,7 @@ public class CancellableTasksTests extends TaskManagerTestCase { // Simulate issuing cancel request on the node that is about to leave the cluster CancelTasksRequest request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setTaskId(new TaskId(testNodes[0].discoveryNode.getId(), mainTask.getId())); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), mainTask.getId())); // And send the cancellation request to a random node CancelTasksResponse response = testNodes[0].transportCancelTasksAction.execute(request).get(); logger.info("--> Done simulating issuing cancel request on the node that is about to leave the cluster"); diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index 7366ce22523..0cece76425d 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.action.admin.cluster.node.tasks; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.Version; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction; @@ -40,6 +41,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.tasks.TaskManager; @@ -58,6 +60,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.function.Supplier; import static java.util.Collections.emptyMap; @@ -169,12 +172,16 @@ public abstract class TaskManagerTestCase extends ESTestCase { public static class TestNode implements Releasable { public TestNode(String name, ThreadPool threadPool, Settings settings) { - clusterService = createClusterService(threadPool); + final Function boundTransportAddressDiscoveryNodeFunction = + address -> { + discoveryNode.set(new DiscoveryNode(name, address.publishAddress(), emptyMap(), emptySet(), Version.CURRENT)); + return discoveryNode.get(); + }; transportService = new TransportService(settings, new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(ClusterModule.getNamedWriteables()), new NetworkService(settings, Collections.emptyList())), - threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> clusterService.localNode(), null) { + threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, boundTransportAddressDiscoveryNodeFunction, null) { @Override protected TaskManager createTaskManager() { if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) { @@ -185,9 +192,8 @@ public abstract class TaskManagerTestCase extends ESTestCase { } }; transportService.start(); + clusterService = createClusterService(threadPool, discoveryNode.get()); clusterService.addStateApplier(transportService.getTaskManager()); - discoveryNode = new DiscoveryNode(name, transportService.boundAddress().publishAddress(), - emptyMap(), emptySet(), Version.CURRENT); IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver(settings); ActionFilters actionFilters = new ActionFilters(emptySet()); transportListTasksAction = new TransportListTasksAction(settings, threadPool, clusterService, transportService, @@ -199,7 +205,7 @@ public abstract class TaskManagerTestCase extends ESTestCase { public final ClusterService clusterService; public final TransportService transportService; - public final DiscoveryNode discoveryNode; + private final SetOnce discoveryNode = new SetOnce<>(); public final TransportListTasksAction transportListTasksAction; public final TransportCancelTasksAction transportCancelTasksAction; @@ -210,22 +216,24 @@ public abstract class TaskManagerTestCase extends ESTestCase { } public String getNodeId() { - return discoveryNode.getId(); + return discoveryNode().getId(); } + + public DiscoveryNode discoveryNode() { return discoveryNode.get(); } } public static void connectNodes(TestNode... nodes) { DiscoveryNode[] discoveryNodes = new DiscoveryNode[nodes.length]; for (int i = 0; i < nodes.length; i++) { - discoveryNodes[i] = nodes[i].discoveryNode; + discoveryNodes[i] = nodes[i].discoveryNode(); } DiscoveryNode master = discoveryNodes[0]; for (TestNode node : nodes) { - setState(node.clusterService, ClusterStateCreationUtils.state(node.discoveryNode, master, discoveryNodes)); + setState(node.clusterService, ClusterStateCreationUtils.state(node.discoveryNode(), master, discoveryNodes)); } for (TestNode nodeA : nodes) { for (TestNode nodeB : nodes) { - nodeA.transportService.connectToNode(nodeB.discoveryNode); + nodeA.transportService.connectToNode(nodeB.discoveryNode()); } } } @@ -233,7 +241,7 @@ public abstract class TaskManagerTestCase extends ESTestCase { public static RecordingTaskManagerListener[] setupListeners(TestNode[] nodes, String... actionMasks) { RecordingTaskManagerListener[] listeners = new RecordingTaskManagerListener[nodes.length]; for (int i = 0; i < nodes.length; i++) { - listeners[i] = new RecordingTaskManagerListener(nodes[i].discoveryNode.getId(), actionMasks); + listeners[i] = new RecordingTaskManagerListener(nodes[i].getNodeId(), actionMasks); ((MockTaskManager) (nodes[i].transportService.getTaskManager())).addListener(listeners[i]); } return listeners; diff --git a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index 07859070d10..4e624164fa0 100644 --- a/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -310,7 +310,7 @@ public class TransportTasksActionTests extends TaskManagerTestCase { Thread.currentThread().interrupt(); } logger.info("Action on node {} finished", node); - return new NodeResponse(testNodes[node].discoveryNode); + return new NodeResponse(testNodes[node].discoveryNode()); } }; } @@ -370,10 +370,10 @@ public class TransportTasksActionTests extends TaskManagerTestCase { assertEquals(testNodes.length, response.getPerNodeTasks().size()); // Coordinating node - assertEquals(2, response.getPerNodeTasks().get(testNodes[0].discoveryNode.getId()).size()); + assertEquals(2, response.getPerNodeTasks().get(testNodes[0].getNodeId()).size()); // Other nodes node for (int i = 1; i < testNodes.length; i++) { - assertEquals(1, response.getPerNodeTasks().get(testNodes[i].discoveryNode.getId()).size()); + assertEquals(1, response.getPerNodeTasks().get(testNodes[i].getNodeId()).size()); } // There should be a single main task when grouped by tasks assertEquals(1, response.getTaskGroups().size()); @@ -535,7 +535,7 @@ public class TransportTasksActionTests extends TaskManagerTestCase { // Try to cancel main task using action name CancelTasksRequest request = new CancelTasksRequest(); - request.setNodes(testNodes[0].discoveryNode.getId()); + request.setNodes(testNodes[0].getNodeId()); request.setReason("Testing Cancellation"); request.setActions(actionName); CancelTasksResponse response = testNodes[randomIntBetween(0, testNodes.length - 1)].transportCancelTasksAction.execute(request) @@ -550,7 +550,7 @@ public class TransportTasksActionTests extends TaskManagerTestCase { // Try to cancel main task using id request = new CancelTasksRequest(); request.setReason("Testing Cancellation"); - request.setTaskId(new TaskId(testNodes[0].discoveryNode.getId(), task.getId())); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), task.getId())); response = testNodes[randomIntBetween(0, testNodes.length - 1)].transportCancelTasksAction.execute(request).get(); // Shouldn't match any tasks since testAction doesn't support cancellation @@ -766,11 +766,11 @@ public class TransportTasksActionTests extends TaskManagerTestCase { byNodes = (Map) byNodes.get("nodes"); // One element on the top level assertEquals(testNodes.length, byNodes.size()); - Map firstNode = (Map) byNodes.get(testNodes[0].discoveryNode.getId()); + Map firstNode = (Map) byNodes.get(testNodes[0].getNodeId()); firstNode = (Map) firstNode.get("tasks"); assertEquals(2, firstNode.size()); // two tasks for the first node for (int i = 1; i < testNodes.length; i++) { - Map otherNode = (Map) byNodes.get(testNodes[i].discoveryNode.getId()); + Map otherNode = (Map) byNodes.get(testNodes[i].getNodeId()); otherNode = (Map) otherNode.get("tasks"); assertEquals(1, otherNode.size()); // one tasks for the all other nodes } diff --git a/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java b/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java index de75b920ce3..e77f34d2eb8 100644 --- a/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java +++ b/core/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.settings.Settings; @@ -48,7 +49,6 @@ import java.util.Map; import java.util.Random; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -182,7 +182,9 @@ abstract class FailAndRetryMockTransport imp } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { } diff --git a/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java b/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java index c7c98aaa99d..e3605bf552d 100644 --- a/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java +++ b/core/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java @@ -50,14 +50,13 @@ import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static org.hamcrest.Matchers.is; - public class TransportClientHeadersTests extends AbstractClientHeadersTestCase { private MockTransportService transportService; @@ -157,15 +156,14 @@ public class TransportClientHeadersTests extends AbstractClientHeadersTestCase { TransportRequest request, TransportRequestOptions options, TransportResponseHandler handler) { + final ClusterName clusterName = new ClusterName("cluster1"); if (TransportLivenessAction.NAME.equals(action)) { assertHeaders(threadPool); ((TransportResponseHandler) handler).handleResponse( - new LivenessResponse(new ClusterName("cluster1"), connection.getNode())); - return; - } - if (ClusterStateAction.NAME.equals(action)) { + new LivenessResponse(clusterName, connection.getNode())); + } else if (ClusterStateAction.NAME.equals(action)) { assertHeaders(threadPool); - ClusterName cluster1 = new ClusterName("cluster1"); + ClusterName cluster1 = clusterName; ClusterState.Builder builder = ClusterState.builder(cluster1); //the sniffer detects only data nodes builder.nodes(DiscoveryNodes.builder().add(new DiscoveryNode("node_id", "someId", "some_ephemeralId_id", @@ -174,10 +172,12 @@ public class TransportClientHeadersTests extends AbstractClientHeadersTestCase { ((TransportResponseHandler) handler) .handleResponse(new ClusterStateResponse(cluster1, builder.build())); clusterStateLatch.countDown(); - return; + } else if (TransportService.HANDSHAKE_ACTION_NAME .equals(action)) { + ((TransportResponseHandler) handler).handleResponse( + new TransportService.HandshakeResponse(connection.getNode(), clusterName, connection.getNode().getVersion())); + } else { + handler.handleException(new TransportException("", new InternalException(action))); } - - handler.handleException(new TransportException("", new InternalException(action))); } }; } diff --git a/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java b/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java index 6efed2638b2..a1b80803e0c 100644 --- a/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java +++ b/core/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java @@ -22,6 +22,7 @@ package org.elasticsearch.cluster; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.LifecycleListener; @@ -198,7 +199,9 @@ public class NodeConnectionsServiceTests extends ESTestCase { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (connectionProfile == null) { if (connectedNodes.contains(node) == false && randomConnectionExceptions && randomBoolean()) { throw new ConnectTransportException(node, "simulated"); diff --git a/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java b/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java index 49af88f5331..b24c5c367b4 100644 --- a/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/ZenFaultDetectionTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.discovery.zen.MasterFaultDetection; import org.elasticsearch.discovery.zen.NodesFaultDetection; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService; +import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -58,8 +59,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; import static java.util.Collections.singleton; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.test.ClusterServiceUtils.setState; @@ -75,10 +74,12 @@ public class ZenFaultDetectionTests extends ESTestCase { protected static final Version version0 = Version.fromId(/*0*/99); protected DiscoveryNode nodeA; protected MockTransportService serviceA; + private Settings settingsA; protected static final Version version1 = Version.fromId(199); protected DiscoveryNode nodeB; protected MockTransportService serviceB; + private Settings settingsB; @Override @Before @@ -89,17 +90,19 @@ public class ZenFaultDetectionTests extends ESTestCase { .build(); ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); threadPool = new TestThreadPool(getClass().getName()); - clusterServiceA = createClusterService(threadPool); - clusterServiceB = createClusterService(threadPool); circuitBreakerService = new HierarchyCircuitBreakerService(settings, clusterSettings); - serviceA = build(Settings.builder().put("name", "TS_A").build(), version0); - nodeA = new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); - serviceB = build(Settings.builder().put("name", "TS_B").build(), version1); - nodeB = new DiscoveryNode("TS_B", "TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + settingsA = Settings.builder().put("node.name", "TS_A").put(settings).build(); + serviceA = build(settingsA, version0); + nodeA = serviceA.getLocalDiscoNode(); + settingsB = Settings.builder().put("node.name", "TS_B").put(settings).build(); + serviceB = build(settingsB, version1); + nodeB = serviceB.getLocalDiscoNode(); + clusterServiceA = createClusterService(settingsA, threadPool, nodeA); + clusterServiceB = createClusterService(settingsB, threadPool, nodeB); // wait till all nodes are properly connected and the event has been sent, so tests in this class // will not get this callback called on the connections done in this setup - final CountDownLatch latch = new CountDownLatch(4); + final CountDownLatch latch = new CountDownLatch(2); TransportConnectionListener waitForConnection = new TransportConnectionListener() { @Override public void onNodeConnected(DiscoveryNode node) { @@ -138,14 +141,20 @@ public class ZenFaultDetectionTests extends ESTestCase { protected MockTransportService build(Settings settings, Version version) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); MockTransportService transportService = - new MockTransportService( - Settings.builder() - // trace zenfd actions but keep the default otherwise - .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), singleton(TransportLivenessAction.NAME)) - .build(), - new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, circuitBreakerService, - namedWriteableRegistry, new NetworkService(settings, Collections.emptyList()), version), - threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, null); + new MockTransportService( + Settings.builder() + .put(settings) + // trace zenfd actions but keep the default otherwise + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), singleton(TransportLivenessAction.NAME)) + .build(), + new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, circuitBreakerService, + namedWriteableRegistry, new NetworkService(settings, Collections.emptyList()), version), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + (boundAddress) -> + new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), boundAddress.publishAddress(), + Node.NODE_ATTRIBUTES.get(settings).getAsMap(), DiscoveryNode.getRolesFromSettings(settings), version), + null); transportService.start(); transportService.acceptIncomingRequests(); return transportService; @@ -170,15 +179,17 @@ public class ZenFaultDetectionTests extends ESTestCase { } public void testNodesFaultDetectionConnectOnDisconnect() throws InterruptedException { - Settings.Builder settings = Settings.builder(); boolean shouldRetry = randomBoolean(); // make sure we don't ping again after the initial ping - settings.put(FaultDetection.CONNECT_ON_NETWORK_DISCONNECT_SETTING.getKey(), shouldRetry) - .put(FaultDetection.PING_INTERVAL_SETTING.getKey(), "5m"); + final Settings pingSettings = Settings.builder() + .put(FaultDetection.CONNECT_ON_NETWORK_DISCONNECT_SETTING.getKey(), shouldRetry) + .put(FaultDetection.PING_INTERVAL_SETTING.getKey(), "5m").build(); ClusterState clusterState = ClusterState.builder(new ClusterName("test")).nodes(buildNodesForA(true)).build(); - NodesFaultDetection nodesFDA = new NodesFaultDetection(settings.build(), threadPool, serviceA, clusterState.getClusterName()); + NodesFaultDetection nodesFDA = new NodesFaultDetection(Settings.builder().put(settingsA).put(pingSettings).build(), + threadPool, serviceA, clusterState.getClusterName()); nodesFDA.setLocalNode(nodeA); - NodesFaultDetection nodesFDB = new NodesFaultDetection(settings.build(), threadPool, serviceB, clusterState.getClusterName()); + NodesFaultDetection nodesFDB = new NodesFaultDetection(Settings.builder().put(settingsB).put(pingSettings).build(), + threadPool, serviceB, clusterState.getClusterName()); nodesFDB.setLocalNode(nodeB); final CountDownLatch pingSent = new CountDownLatch(1); nodesFDB.addListener(new NodesFaultDetection.Listener() { @@ -260,13 +271,12 @@ public class ZenFaultDetectionTests extends ESTestCase { } public void testMasterFaultDetectionNotSizeLimited() throws InterruptedException { - Settings.Builder settings = Settings.builder(); boolean shouldRetry = randomBoolean(); ClusterName clusterName = new ClusterName(randomAsciiOfLengthBetween(3, 20)); - settings + final Settings settings = Settings.builder() .put(FaultDetection.CONNECT_ON_NETWORK_DISCONNECT_SETTING.getKey(), shouldRetry) .put(FaultDetection.PING_INTERVAL_SETTING.getKey(), "1s") - .put("cluster.name", clusterName.value()); + .put("cluster.name", clusterName.value()).build(); final ClusterState stateNodeA = ClusterState.builder(clusterName).nodes(buildNodesForA(false)).build(); setState(clusterServiceA, stateNodeA); @@ -278,15 +288,15 @@ public class ZenFaultDetectionTests extends ESTestCase { serviceA.addTracer(pingProbeA); serviceB.addTracer(pingProbeB); - MasterFaultDetection masterFDNodeA = new MasterFaultDetection(settings.build(), threadPool, serviceA, - clusterServiceA); + MasterFaultDetection masterFDNodeA = new MasterFaultDetection(Settings.builder().put(settingsA).put(settings).build(), + threadPool, serviceA, clusterServiceA); masterFDNodeA.start(nodeB, "test"); final ClusterState stateNodeB = ClusterState.builder(clusterName).nodes(buildNodesForB(true)).build(); setState(clusterServiceB, stateNodeB); - MasterFaultDetection masterFDNodeB = new MasterFaultDetection(settings.build(), threadPool, serviceB, - clusterServiceB); + MasterFaultDetection masterFDNodeB = new MasterFaultDetection(Settings.builder().put(settingsB).put(settings).build(), + threadPool, serviceB, clusterServiceB); masterFDNodeB.start(nodeB, "test"); // let's do a few pings diff --git a/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java b/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java index 5b6668c6f5e..2cf623b702a 100644 --- a/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java @@ -43,7 +43,6 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; -import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -168,11 +167,10 @@ public class PublishClusterStateActionTests extends ESTestCase { .build(); MockTransportService service = buildTransportService(settings, threadPool); - DiscoveryNode discoveryNode = DiscoveryNode.createLocal(settings, service.boundAddress().publishAddress(), - NodeEnvironment.generateNodeId(settings)); + DiscoveryNode discoveryNode = service.getLocalDiscoNode(); MockNode node = new MockNode(discoveryNode, service, listener, logger); node.action = buildPublishClusterStateAction(settings, service, () -> node.clusterState, node); - final CountDownLatch latch = new CountDownLatch(nodes.size() * 2 + 1); + final CountDownLatch latch = new CountDownLatch(nodes.size() * 2); TransportConnectionListener waitForConnection = new TransportConnectionListener() { @Override public void onNodeConnected(DiscoveryNode node) { @@ -190,7 +188,6 @@ public class PublishClusterStateActionTests extends ESTestCase { curNode.connectTo(node.discoveryNode); node.connectTo(curNode.discoveryNode); } - node.connectTo(node.discoveryNode); assertThat("failed to wait for all nodes to connect", latch.await(5, TimeUnit.SECONDS), equalTo(true)); for (MockNode curNode : nodes.values()) { curNode.service.removeConnectionListener(waitForConnection); diff --git a/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java b/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java index 40309a38f28..8eeced6cfcd 100644 --- a/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java +++ b/core/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.cluster.block.ClusterBlocks; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode.Role; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.network.NetworkService; @@ -44,6 +45,7 @@ import org.elasticsearch.test.VersionUtils; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.MockTcpTransport; import org.elasticsearch.transport.Transport; @@ -148,7 +150,9 @@ public class UnicastZenPingTests extends ESTestCase { networkService, v) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { throw new AssertionError("zen pings should never connect to node (got [" + node + "])"); } }; diff --git a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java index 1785853d0e1..b18b57e3717 100644 --- a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java +++ b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java @@ -29,10 +29,15 @@ public class ConnectionProfileTests extends ESTestCase { public void testBuildConnectionProfile() { ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); TimeValue connectTimeout = TimeValue.timeValueMillis(randomIntBetween(1, 10)); + TimeValue handshaketTimeout = TimeValue.timeValueMillis(randomIntBetween(1, 10)); final boolean setConnectTimeout = randomBoolean(); if (setConnectTimeout) { builder.setConnectTimeout(connectTimeout); } + final boolean setHandshakeTimeout = randomBoolean(); + if (setHandshakeTimeout) { + builder.setHandshakeTimeout(handshaketTimeout); + } builder.addConnections(1, TransportRequestOptions.Type.BULK); builder.addConnections(2, TransportRequestOptions.Type.STATE, TransportRequestOptions.Type.RECOVERY); builder.addConnections(3, TransportRequestOptions.Type.PING); @@ -44,12 +49,22 @@ public class ConnectionProfileTests extends ESTestCase { assertEquals("type [PING] is already registered", illegalArgumentException.getMessage()); builder.addConnections(4, TransportRequestOptions.Type.REG); ConnectionProfile build = builder.build(); + if (randomBoolean()) { + build = new ConnectionProfile.Builder(build).build(); + } assertEquals(10, build.getNumConnections()); if (setConnectTimeout) { assertEquals(connectTimeout, build.getConnectTimeout()); } else { assertNull(build.getConnectTimeout()); } + + if (setHandshakeTimeout) { + assertEquals(handshaketTimeout, build.getHandshakeTimeout()); + } else { + assertNull(build.getHandshakeTimeout()); + } + Integer[] array = new Integer[10]; for (int i = 0; i < array.length; i++) { array[i] = i; diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index 0b46843cdb7..c84fa38edc0 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; @@ -38,6 +39,8 @@ import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import static org.hamcrest.Matchers.equalTo; + /** Unit tests for TCPTransport */ public class TCPTransportTests extends ESTestCase { @@ -239,6 +242,38 @@ public class TCPTransportTests extends ESTestCase { } } + public void testConnectionProfileResolve() { + final ConnectionProfile defaultProfile = TcpTransport.buildDefaultConnectionProfile(Settings.EMPTY); + assertEquals(defaultProfile, TcpTransport.resolveConnectionProfile(null, defaultProfile)); + + final ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.BULK); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.RECOVERY); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.REG); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.STATE); + builder.addConnections(randomIntBetween(0, 5), TransportRequestOptions.Type.PING); + + final boolean connectionTimeoutSet = randomBoolean(); + if (connectionTimeoutSet) { + builder.setConnectTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + final boolean connectionHandshakeSet = randomBoolean(); + if (connectionHandshakeSet) { + builder.setHandshakeTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + + final ConnectionProfile profile = builder.build(); + final ConnectionProfile resolved = TcpTransport.resolveConnectionProfile(profile, defaultProfile); + assertNotEquals(resolved, defaultProfile); + assertThat(resolved.getNumConnections(), equalTo(profile.getNumConnections())); + assertThat(resolved.getHandles(), equalTo(profile.getHandles())); + + assertThat(resolved.getConnectTimeout(), + equalTo(connectionTimeoutSet ? profile.getConnectTimeout() : defaultProfile.getConnectTimeout())); + assertThat(resolved.getHandshakeTimeout(), + equalTo(connectionHandshakeSet ? profile.getHandshakeTimeout() : defaultProfile.getHandshakeTimeout())); + } + public void testDefaultConnectionProfile() { ConnectionProfile profile = TcpTransport.buildDefaultConnectionProfile(Settings.EMPTY); assertEquals(13, profile.getNumConnections()); diff --git a/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 68dffa1ded9..e1cfc08dbd0 100644 --- a/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -35,9 +35,6 @@ import org.junit.Before; import java.io.IOException; import java.util.concurrent.CountDownLatch; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; - public class TransportActionProxyTests extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -61,11 +58,11 @@ public class TransportActionProxyTests extends ESTestCase { super.setUp(); threadPool = new TestThreadPool(getClass().getName()); serviceA = buildService(version0); // this one supports dynamic tracer updates - nodeA = new DiscoveryNode("TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + nodeA = serviceA.getLocalDiscoNode(); serviceB = buildService(version1); // this one doesn't support dynamic tracer updates - nodeB = new DiscoveryNode("TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + nodeB = serviceB.getLocalDiscoNode(); serviceC = buildService(version1); // this one doesn't support dynamic tracer updates - nodeC = new DiscoveryNode("TS_C", serviceC.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + nodeC = serviceC.getLocalDiscoNode(); } @Override diff --git a/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java b/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java index c00f5fb07a5..3b6165adab1 100644 --- a/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java @@ -161,6 +161,24 @@ public class TransportServiceHandshakeTests extends ESTestCase { assertFalse(handleA.transportService.nodeConnected(discoveryNode)); } + public void testNodeConnectWithDifferentNodeId() { + Settings settings = Settings.builder().put("cluster.name", "test").build(); + NetworkHandle handleA = startServices("TS_A", settings, Version.CURRENT); + NetworkHandle handleB = startServices("TS_B", settings, Version.CURRENT); + DiscoveryNode discoveryNode = new DiscoveryNode( + randomAsciiOfLength(10), + handleB.discoveryNode.getAddress(), + emptyMap(), + emptySet(), + handleB.discoveryNode.getVersion()); + ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> { + handleA.transportService.connectToNode(discoveryNode, MockTcpTransport.LIGHT_PROFILE); + }); + assertThat(ex.getMessage(), containsString("unexpected remote node")); + assertFalse(handleA.transportService.nodeConnected(discoveryNode)); + } + + private static class NetworkHandle { private TransportService transportService; private DiscoveryNode discoveryNode; diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java index 2786077d084..8bfdbb739d7 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java @@ -18,7 +18,6 @@ */ package org.elasticsearch.transport.netty4; -import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.lease.Releasables; @@ -46,8 +45,6 @@ import org.elasticsearch.transport.TransportSettings; import java.io.IOException; import java.util.Collections; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -79,10 +76,8 @@ public class Netty4ScheduledPingTests extends ESTestCase { serviceB.start(); serviceB.acceptIncomingRequests(); - DiscoveryNode nodeA = - new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), Version.CURRENT); - DiscoveryNode nodeB = - new DiscoveryNode("TS_B", "TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), Version.CURRENT); + DiscoveryNode nodeA = serviceA.getLocalDiscoNode(); + DiscoveryNode nodeB = serviceB.getLocalDiscoNode(); serviceA.connectToNode(nodeB); serviceB.connectToNode(nodeA); diff --git a/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java b/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java index 07e7a25324b..01f626a1e2b 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ClusterServiceUtils.java @@ -34,7 +34,6 @@ import org.elasticsearch.threadpool.ThreadPool; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; -import java.util.List; import java.util.concurrent.CountDownLatch; import static junit.framework.TestCase.fail; @@ -48,8 +47,13 @@ public class ClusterServiceUtils { } public static ClusterService createClusterService(ThreadPool threadPool, DiscoveryNode localNode) { - ClusterService clusterService = new ClusterService(Settings.builder().put("cluster.name", "ClusterServiceTests").build(), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + return createClusterService(Settings.EMPTY, threadPool, localNode); + } + + public static ClusterService createClusterService(Settings settings, ThreadPool threadPool, DiscoveryNode localNode) { + ClusterService clusterService = new ClusterService( + Settings.builder().put("cluster.name", "ClusterServiceTests").put(settings).build(), + new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), threadPool, () -> localNode); clusterService.setNodeConnectionsService(new NodeConnectionsService(Settings.EMPTY, null, null) { @Override diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java b/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java index ffccdaac722..55519ec2af2 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/CapturingTransport.java @@ -21,6 +21,7 @@ package org.elasticsearch.test.transport; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.component.Lifecycle; @@ -238,7 +239,9 @@ public class CapturingTransport implements Transport { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { } diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java index b76f9fe6a70..117c168bc41 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java @@ -22,6 +22,7 @@ package org.elasticsearch.test.transport; import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.component.Lifecycle; @@ -99,7 +100,16 @@ public final class MockTransportService extends TransportService { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables()); final Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(settings, Collections.emptyList()), version); - return new MockTransportService(settings, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, clusterSettings); + return createNewService(settings, transport, version, threadPool, clusterSettings); + } + + public static MockTransportService createNewService(Settings settings, Transport transport, Version version, ThreadPool threadPool, + @Nullable ClusterSettings clusterSettings) { + return new MockTransportService(settings, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundAddress -> + new DiscoveryNode(Node.NODE_NAME_SETTING.get(settings), UUIDs.randomBase64UUID(), boundAddress.publishAddress(), + Node.NODE_ATTRIBUTES.get(settings).getAsMap(), DiscoveryNode.getRolesFromSettings(settings), version), + clusterSettings); } private final Transport original; @@ -198,7 +208,9 @@ public final class MockTransportService extends TransportService { addDelegate(transportAddress, new DelegateTransport(original) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (original.nodeConnected(node) == false) { // connecting to an already connected node is a no-op throw new ConnectTransportException(node, "DISCONNECT: simulated"); @@ -244,8 +256,10 @@ public final class MockTransportService extends TransportService { addDelegate(transportAddress, new DelegateTransport(original) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { - original.connectToNode(node, connectionProfile); + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + original.connectToNode(node, connectionProfile, connectionValidator); } @Override @@ -278,7 +292,9 @@ public final class MockTransportService extends TransportService { addDelegate(transportAddress, new DelegateTransport(original) { @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (original.nodeConnected(node) == false) { // connecting to an already connected node is a no-op throw new ConnectTransportException(node, "UNRESPONSIVE: simulated"); @@ -323,14 +339,16 @@ public final class MockTransportService extends TransportService { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { if (original.nodeConnected(node)) { // connecting to an already connected node is a no-op return; } TimeValue delay = getDelay(); if (delay.millis() <= 0) { - original.connectToNode(node, connectionProfile); + original.connectToNode(node, connectionProfile, connectionValidator); return; } @@ -339,7 +357,7 @@ public final class MockTransportService extends TransportService { try { if (delay.millis() < connectingTimeout.millis()) { Thread.sleep(delay.millis()); - original.connectToNode(node, connectionProfile); + original.connectToNode(node, connectionProfile, connectionValidator); } else { Thread.sleep(connectingTimeout.millis()); throw new ConnectTransportException(node, "UNRESPONSIVE: simulated"); @@ -486,10 +504,11 @@ public final class MockTransportService extends TransportService { return getTransport(node).nodeConnected(node); } - @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { - getTransport(node).connectToNode(node, connectionProfile); + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + getTransport(node).connectToNode(node, connectionProfile, connectionValidator); } @Override @@ -542,8 +561,10 @@ public final class MockTransportService extends TransportService { } @Override - public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile) throws ConnectTransportException { - transport.connectToNode(node, connectionProfile); + public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, + CheckedBiConsumer connectionValidator) + throws ConnectTransportException { + transport.connectToNode(node, connectionProfile, connectionValidator); } @Override diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 865ec10430e..47e85fa58b7 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -24,6 +24,7 @@ import org.apache.logging.log4j.util.Supplier; import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.IOUtils; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListenerResponseHandler; @@ -72,6 +73,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; @@ -361,6 +363,101 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertThat(responseString.get(), equalTo("test")); } + public void testAdapterSendReceiveCallbacks() throws Exception { + final TransportRequestHandler requestHandler = (request, channel) -> { + try { + if (randomBoolean()) { + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } else { + channel.sendResponse(new ElasticsearchException("simulated")); + } + } catch (IOException e) { + logger.error("Unexpected failure", e); + fail(e.getMessage()); + } + }; + serviceA.registerRequestHandler("action", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, + requestHandler); + serviceB.registerRequestHandler("action", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, + requestHandler); + + + class CountingTracer extends MockTransportService.Tracer { + AtomicInteger requestsReceived = new AtomicInteger(); + AtomicInteger requestsSent = new AtomicInteger(); + AtomicInteger responseReceived = new AtomicInteger(); + AtomicInteger responseSent = new AtomicInteger(); + @Override + public void receivedRequest(long requestId, String action) { + requestsReceived.incrementAndGet(); + } + + @Override + public void responseSent(long requestId, String action) { + responseSent.incrementAndGet(); + } + + @Override + public void responseSent(long requestId, String action, Throwable t) { + responseSent.incrementAndGet(); + } + + @Override + public void receivedResponse(long requestId, DiscoveryNode sourceNode, String action) { + responseReceived.incrementAndGet(); + } + + @Override + public void requestSent(DiscoveryNode node, long requestId, String action, TransportRequestOptions options) { + requestsSent.incrementAndGet(); + } + } + final CountingTracer tracerA = new CountingTracer(); + final CountingTracer tracerB = new CountingTracer(); + serviceA.addTracer(tracerA); + serviceB.addTracer(tracerB); + + try { + serviceA + .submitRequest(nodeB, "action", TransportRequest.Empty.INSTANCE, EmptyTransportResponseHandler.INSTANCE_SAME).get(); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(ElasticsearchException.class)); + assertThat(ExceptionsHelper.unwrapCause(e.getCause()).getMessage(), equalTo("simulated")); + } + + // use assert busy as call backs are sometime called after the response have been sent + assertBusy(() -> { + assertThat(tracerA.requestsReceived.get(), equalTo(0)); + assertThat(tracerA.requestsSent.get(), equalTo(1)); + assertThat(tracerA.responseReceived.get(), equalTo(1)); + assertThat(tracerA.responseSent.get(), equalTo(0)); + assertThat(tracerB.requestsReceived.get(), equalTo(1)); + assertThat(tracerB.requestsSent.get(), equalTo(0)); + assertThat(tracerB.responseReceived.get(), equalTo(0)); + assertThat(tracerB.responseSent.get(), equalTo(1)); + }); + + try { + serviceA + .submitRequest(nodeA, "action", TransportRequest.Empty.INSTANCE, EmptyTransportResponseHandler.INSTANCE_SAME).get(); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(ElasticsearchException.class)); + assertThat(ExceptionsHelper.unwrapCause(e.getCause()).getMessage(), equalTo("simulated")); + } + + // use assert busy as call backs are sometime called after the response have been sent + assertBusy(() -> { + assertThat(tracerA.requestsReceived.get(), equalTo(1)); + assertThat(tracerA.requestsSent.get(), equalTo(2)); + assertThat(tracerA.responseReceived.get(), equalTo(2)); + assertThat(tracerA.responseSent.get(), equalTo(1)); + assertThat(tracerB.requestsReceived.get(), equalTo(1)); + assertThat(tracerB.requestsSent.get(), equalTo(0)); + assertThat(tracerB.responseReceived.get(), equalTo(0)); + assertThat(tracerB.responseSent.get(), equalTo(1)); + }); + } + public void testVoidMessageCompressed() { serviceA.registerRequestHandler("sayHello", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, (request, channel) -> { @@ -621,7 +718,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { MockTransportService newService = buildService("TS_B_" + i, version1, null); newService.registerRequestHandler("test", TestRequest::new, ThreadPool.Names.SAME, ignoringRequestHandler); serviceB = newService; - nodeB = new DiscoveryNode("TS_B_" + i, "TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + nodeB = newService.getLocalDiscoNode(); serviceB.connectToNode(nodeA); serviceA.connectToNode(nodeB); } else if (serviceA.nodeConnected(nodeB)) { @@ -1467,42 +1564,42 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { channel.sendResponse(TransportResponse.Empty.INSTANCE); }); - DiscoveryNode node = - new DiscoveryNode("TS_TEST", "TS_TEST", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + DiscoveryNode node = service.getLocalNode(); serviceA.close(); serviceA = buildService("TS_A", version0, null, Settings.EMPTY, true, false); - serviceA.connectToNode(node); + try (Transport.Connection connection = serviceA.openConnection(node, null)) { + CountDownLatch latch = new CountDownLatch(1); + serviceA.sendRequest(connection, "action", new TestRequest(), TransportRequestOptions.EMPTY, + new TransportResponseHandler() { + @Override + public TestResponse newInstance() { + return new TestResponse(); + } - CountDownLatch latch = new CountDownLatch(1); - serviceA.sendRequest(node, "action", new TestRequest(), new TransportResponseHandler() { - @Override - public TestResponse newInstance() { - return new TestResponse(); - } + @Override + public void handleResponse(TestResponse response) { + latch.countDown(); + } - @Override - public void handleResponse(TestResponse response) { - latch.countDown(); - } + @Override + public void handleException(TransportException exp) { + latch.countDown(); + } - @Override - public void handleException(TransportException exp) { - latch.countDown(); - } + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }); - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - }); + assertFalse(requestProcessed.get()); - assertFalse(requestProcessed.get()); + service.acceptIncomingRequests(); + assertBusy(() -> assertTrue(requestProcessed.get())); - service.acceptIncomingRequests(); - assertBusy(() -> assertTrue(requestProcessed.get())); - - latch.await(); + latch.await(); + } } } @@ -1781,12 +1878,12 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { // connection with one connection and a large timeout -- should consume the one spot in the backlog queue try (TransportService service = buildService("TS_TPC", Version.CURRENT, null, Settings.EMPTY, true, false)) { - service.connectToNode(first, builder.build()); + IOUtils.close(service.openConnection(first, builder.build())); builder.setConnectTimeout(TimeValue.timeValueMillis(1)); final ConnectionProfile profile = builder.build(); // now with the 1ms timeout we got and test that is it's applied long startTime = System.nanoTime(); - ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> service.connectToNode(second, profile)); + ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> service.openConnection(second, profile)); final long now = System.nanoTime(); final long timeTaken = TimeValue.nsecToMSec(now - startTime); assertTrue("test didn't timeout quick enough, time taken: [" + timeTaken + "]", @@ -1867,13 +1964,13 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { assertEquals("handshake failed", exception.getCause().getMessage()); } - try (TransportService service = buildService("TS_TPC", Version.CURRENT, null)) { - DiscoveryNode node = - new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); - serviceA.connectToNode(node); - TcpTransport.NodeChannels connection = originalTransport.getConnection(node); - Version version = originalTransport.executeHandshake(node, connection.channel(TransportRequestOptions.Type.PING), - TimeValue.timeValueSeconds(10)); + try (TransportService service = buildService("TS_TPC", Version.CURRENT, null); + TcpTransport.NodeChannels connection = originalTransport.openConnection( + new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0), + null + ) ) { + Version version = originalTransport.executeHandshake(connection.getNode(), + connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10)); assertEquals(version, Version.CURRENT); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 2bb6b87fe15..1b2384ba5fc 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -202,8 +202,7 @@ public class MockTcpTransport extends TcpTransport final InetSocketAddress address = node.getAddress().address(); // we just use a single connections configureSocket(socket); - final TimeValue connectTimeout = profile.getConnectTimeout() == null ? defaultConnectionProfile.getConnectTimeout() - : profile.getConnectTimeout(); + final TimeValue connectTimeout = profile.getConnectTimeout(); try { socket.connect(address, Math.toIntExact(connectTimeout.millis())); } catch (SocketTimeoutException ex) { diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index 2cc84c4c0cd..75d450b5d53 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -48,8 +48,8 @@ public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { } } }; - MockTransportService mockTransportService = new MockTransportService(Settings.EMPTY, transport, threadPool, - TransportService.NOOP_TRANSPORT_INTERCEPTOR, clusterSettings); + MockTransportService mockTransportService = + MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings); mockTransportService.start(); return mockTransportService; }