Remove RemoteClusterConnection.ConnectedNodes (#44235)

This instead exposes the set of connected nodes on ConnectionManager.
This commit is contained in:
Yannick Welsch 2019-07-12 14:23:06 +02:00
parent 40d3c60d7a
commit 068286ca4b
8 changed files with 59 additions and 121 deletions

View File

@ -33,9 +33,11 @@ import org.elasticsearch.core.internal.io.IOUtils;
import java.io.Closeable; import java.io.Closeable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -238,6 +240,13 @@ public class ConnectionManager implements Closeable {
return connectedNodes.size(); return connectedNodes.size();
} }
/**
* Returns the set of nodes this manager is connected to.
*/
public Set<DiscoveryNode> connectedNodes() {
return Collections.unmodifiableSet(connectedNodes.keySet());
}
@Override @Override
public void close() { public void close() {
assert Transports.assertNotTransportThread("Closing ConnectionManager"); assert Transports.assertNotTransportThread("Closing ConnectionManager");

View File

@ -50,16 +50,15 @@ import java.net.InetSocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -84,7 +83,6 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
private final TransportService transportService; private final TransportService transportService;
private final ConnectionManager connectionManager; private final ConnectionManager connectionManager;
private final ConnectedNodes connectedNodes;
private final String clusterAlias; private final String clusterAlias;
private final int maxNumRemoteConnections; private final int maxNumRemoteConnections;
private final Predicate<DiscoveryNode> nodePredicate; private final Predicate<DiscoveryNode> nodePredicate;
@ -123,7 +121,6 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
this.nodePredicate = nodePredicate; this.nodePredicate = nodePredicate;
this.clusterAlias = clusterAlias; this.clusterAlias = clusterAlias;
this.connectionManager = connectionManager; this.connectionManager = connectionManager;
this.connectedNodes = new ConnectedNodes(clusterAlias);
this.seedNodes = Collections.unmodifiableList(seedNodes); this.seedNodes = Collections.unmodifiableList(seedNodes);
this.skipUnavailable = RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE this.skipUnavailable = RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE
.getConcreteSettingForNamespace(clusterAlias).get(settings); .getConcreteSettingForNamespace(clusterAlias).get(settings);
@ -176,8 +173,7 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
@Override @Override
public void onNodeDisconnected(DiscoveryNode node) { public void onNodeDisconnected(DiscoveryNode node) {
boolean remove = connectedNodes.remove(node); if (connectionManager.size() < maxNumRemoteConnections) {
if (remove && connectedNodes.size() < maxNumRemoteConnections) {
// try to reconnect and fill up the slot of the disconnected node // try to reconnect and fill up the slot of the disconnected node
connectHandler.forceConnect(); connectHandler.forceConnect();
} }
@ -188,7 +184,7 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
* will invoke the listener immediately. * will invoke the listener immediately.
*/ */
void ensureConnected(ActionListener<Void> voidActionListener) { void ensureConnected(ActionListener<Void> voidActionListener) {
if (connectedNodes.size() == 0) { if (connectionManager.size() == 0) {
connectHandler.connect(voidActionListener); connectHandler.connect(voidActionListener);
} else { } else {
voidActionListener.onResponse(null); voidActionListener.onResponse(null);
@ -466,14 +462,13 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
} }
final DiscoveryNode handshakeNode = maybeAddProxyAddress(proxyAddress, handshakeResponse.getDiscoveryNode()); final DiscoveryNode handshakeNode = maybeAddProxyAddress(proxyAddress, handshakeResponse.getDiscoveryNode());
if (nodePredicate.test(handshakeNode) && connectedNodes.size() < maxNumRemoteConnections) { if (nodePredicate.test(handshakeNode) && manager.size() < maxNumRemoteConnections) {
PlainActionFuture.get(fut -> manager.connectToNode(handshakeNode, null, PlainActionFuture.get(fut -> manager.connectToNode(handshakeNode, null,
transportService.connectionValidator(handshakeNode), ActionListener.map(fut, x -> null))); transportService.connectionValidator(handshakeNode), ActionListener.map(fut, x -> null)));
if (remoteClusterName.get() == null) { if (remoteClusterName.get() == null) {
assert handshakeResponse.getClusterName().value() != null; assert handshakeResponse.getClusterName().value() != null;
remoteClusterName.set(handshakeResponse.getClusterName()); remoteClusterName.set(handshakeResponse.getClusterName());
} }
connectedNodes.add(handshakeNode);
} }
ClusterStateRequest request = new ClusterStateRequest(); ClusterStateRequest request = new ClusterStateRequest();
request.clear(); request.clear();
@ -580,12 +575,11 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
Iterable<DiscoveryNode> nodesIter = nodes.getNodes()::valuesIt; Iterable<DiscoveryNode> nodesIter = nodes.getNodes()::valuesIt;
for (DiscoveryNode n : nodesIter) { for (DiscoveryNode n : nodesIter) {
DiscoveryNode node = maybeAddProxyAddress(proxyAddress, n); DiscoveryNode node = maybeAddProxyAddress(proxyAddress, n);
if (nodePredicate.test(node) && connectedNodes.size() < maxNumRemoteConnections) { if (nodePredicate.test(node) && connectionManager.size() < maxNumRemoteConnections) {
try { try {
// noop if node is connected // noop if node is connected
PlainActionFuture.get(fut -> connectionManager.connectToNode(node, null, PlainActionFuture.get(fut -> connectionManager.connectToNode(node, null,
transportService.connectionValidator(node), ActionListener.map(fut, x -> null))); transportService.connectionValidator(node), ActionListener.map(fut, x -> null)));
connectedNodes.add(node);
} catch (ConnectTransportException | IllegalStateException ex) { } catch (ConnectTransportException | IllegalStateException ex) {
// ISE if we fail the handshake with an version incompatible node // ISE if we fail the handshake with an version incompatible node
// fair enough we can't connect just move on // fair enough we can't connect just move on
@ -628,15 +622,20 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
} }
boolean isNodeConnected(final DiscoveryNode node) { boolean isNodeConnected(final DiscoveryNode node) {
return connectedNodes.contains(node); return connectionManager.nodeConnected(node);
} }
private final AtomicLong nextNodeId = new AtomicLong();
DiscoveryNode getAnyConnectedNode() { DiscoveryNode getAnyConnectedNode() {
return connectedNodes.getAny(); List<DiscoveryNode> nodes = new ArrayList<>(connectionManager.connectedNodes());
} if (nodes.isEmpty()) {
throw new NoSuchRemoteClusterException(clusterAlias);
void addConnectedNode(DiscoveryNode node) { } else {
connectedNodes.add(node); long curr;
while ((curr = nextNodeId.incrementAndGet()) == Long.MIN_VALUE);
return nodes.get(Math.toIntExact(Math.floorMod(curr, nodes.size())));
}
} }
/** /**
@ -647,67 +646,13 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
clusterAlias, clusterAlias,
seedNodes.stream().map(Tuple::v1).collect(Collectors.toList()), seedNodes.stream().map(Tuple::v1).collect(Collectors.toList()),
maxNumRemoteConnections, maxNumRemoteConnections,
connectedNodes.size(), getNumNodesConnected(),
initialConnectionTimeout, initialConnectionTimeout,
skipUnavailable); skipUnavailable);
} }
int getNumNodesConnected() { int getNumNodesConnected() {
return connectedNodes.size(); return connectionManager.size();
}
private static final class ConnectedNodes {
private final Set<DiscoveryNode> nodeSet = new HashSet<>();
private final String clusterAlias;
private Iterator<DiscoveryNode> currentIterator = null;
private ConnectedNodes(String clusterAlias) {
this.clusterAlias = clusterAlias;
}
public synchronized DiscoveryNode getAny() {
ensureIteratorAvailable();
if (currentIterator.hasNext()) {
return currentIterator.next();
} else {
throw new NoSuchRemoteClusterException(clusterAlias);
}
}
synchronized boolean remove(DiscoveryNode node) {
final boolean setRemoval = nodeSet.remove(node);
if (setRemoval) {
currentIterator = null;
}
return setRemoval;
}
synchronized boolean add(DiscoveryNode node) {
final boolean added = nodeSet.add(node);
if (added) {
currentIterator = null;
}
return added;
}
synchronized int size() {
return nodeSet.size();
}
synchronized boolean contains(DiscoveryNode node) {
return nodeSet.contains(node);
}
private synchronized void ensureIteratorAvailable() {
if (currentIterator == null) {
currentIterator = nodeSet.iterator();
} else if (currentIterator.hasNext() == false && nodeSet.isEmpty() == false) {
// iterator rollover
currentIterator = nodeSet.iterator();
}
}
} }
private static ConnectionManager createConnectionManager(ConnectionProfile connectionProfile, TransportService transportService) { private static ConnectionManager createConnectionManager(ConnectionProfile connectionProfile, TransportService transportService) {

View File

@ -166,7 +166,7 @@ public class TransportClientNodesServiceTests extends ESTestCase {
assert addr == null : "boundAddress: " + addr; assert addr == null : "boundAddress: " + addr;
return DiscoveryNode.createLocal(settings, buildNewFakeTransportAddress(), UUIDs.randomBase64UUID()); return DiscoveryNode.createLocal(settings, buildNewFakeTransportAddress(), UUIDs.randomBase64UUID());
}, null, Collections.emptySet()); }, null, Collections.emptySet());
transportService.addNodeConnectedBehavior((connectionManager, discoveryNode) -> false); transportService.addNodeConnectedBehavior(cm -> Collections.emptySet());
transportService.addGetConnectionBehavior((connectionManager, discoveryNode) -> { transportService.addGetConnectionBehavior((connectionManager, discoveryNode) -> {
// The FailAndRetryTransport does not use the connection profile // The FailAndRetryTransport does not use the connection profile
PlainActionFuture<Transport.Connection> future = PlainActionFuture.newFuture(); PlainActionFuture<Transport.Connection> future = PlainActionFuture.newFuture();

View File

@ -30,6 +30,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodes.Builder;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.discovery.PeerFinder.TransportAddressConnector; import org.elasticsearch.discovery.PeerFinder.TransportAddressConnector;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.test.transport.CapturingTransport;
@ -214,11 +215,9 @@ public class PeerFinderTests extends ESTestCase {
= new ConnectionManager(settings, capturingTransport); = new ConnectionManager(settings, capturingTransport);
StubbableConnectionManager connectionManager StubbableConnectionManager connectionManager
= new StubbableConnectionManager(innerConnectionManager, settings, capturingTransport); = new StubbableConnectionManager(innerConnectionManager, settings, capturingTransport);
connectionManager.setDefaultNodeConnectedBehavior((cm, discoveryNode) -> { connectionManager.setDefaultNodeConnectedBehavior(cm -> {
final boolean isConnected = connectedNodes.contains(discoveryNode); assertTrue(Sets.haveEmptyIntersection(connectedNodes, disconnectedNodes));
final boolean isDisconnected = disconnectedNodes.contains(discoveryNode); return connectedNodes;
assert isConnected != isDisconnected : discoveryNode + ": isConnected=" + isConnected + ", isDisconnected=" + isDisconnected;
return isConnected;
}); });
connectionManager.setDefaultGetConnectionBehavior((cm, discoveryNode) -> capturingTransport.createConnection(discoveryNode)); connectionManager.setDefaultGetConnectionBehavior((cm, discoveryNode) -> capturingTransport.createConnection(discoveryNode));
transportService = new TransportService(settings, capturingTransport, deterministicTaskQueue.getThreadPool(), transportService = new TransportService(settings, capturingTransport, deterministicTaskQueue.getThreadPool(),

View File

@ -34,6 +34,7 @@ import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
@ -479,9 +480,10 @@ public class RemoteClusterConnectionTests extends ESTestCase {
public void testRemoteConnectionVersionMatchesTransportConnectionVersion() throws Exception { public void testRemoteConnectionVersionMatchesTransportConnectionVersion() throws Exception {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>(); List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
final Version previousVersion = VersionUtils.getPreviousVersion(); final Version previousVersion = randomValueOtherThan(Version.CURRENT, () -> VersionUtils.randomVersionBetween(random(),
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, previousVersion); Version.CURRENT.minimumCompatibilityVersion(), Version.CURRENT));
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) { try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, previousVersion)) {
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode(); DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
assertThat(seedNode, notNullValue()); assertThat(seedNode, notNullValue());
@ -520,12 +522,10 @@ public class RemoteClusterConnectionTests extends ESTestCase {
service.acceptIncomingRequests(); service.acceptIncomingRequests();
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster", try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
seedNodes(seedNode), service, Integer.MAX_VALUE, n -> true, null, connectionManager)) { seedNodes(seedNode), service, Integer.MAX_VALUE, n -> true, null, connectionManager)) {
connection.addConnectedNode(seedNode); PlainActionFuture.get(fut -> connection.ensureConnected(ActionListener.map(fut, x -> null)));
for (DiscoveryNode node : knownNodes) {
final Transport.Connection transportConnection = connection.getConnection(node);
assertThat(transportConnection.getVersion(), equalTo(previousVersion));
}
assertThat(knownNodes, iterableWithSize(2)); assertThat(knownNodes, iterableWithSize(2));
assertThat(connection.getConnection(seedNode).getVersion(), equalTo(Version.CURRENT));
assertThat(connection.getConnection(oldVersionNode).getVersion(), equalTo(previousVersion));
} }
} }
} }
@ -1007,7 +1007,7 @@ public class RemoteClusterConnectionTests extends ESTestCase {
discoverableTransports.add(transportService); discoverableTransports.add(transportService);
} }
List<Tuple<String, Supplier<DiscoveryNode>>> seedNodes = randomSubsetOf(discoverableNodes); List<Tuple<String, Supplier<DiscoveryNode>>> seedNodes = new CopyOnWriteArrayList<>(randomSubsetOf(discoverableNodes));
Collections.shuffle(seedNodes, random()); Collections.shuffle(seedNodes, random());
try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) { try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
@ -1048,11 +1048,14 @@ public class RemoteClusterConnectionTests extends ESTestCase {
barrier.await(); barrier.await();
for (int j = 0; j < numDisconnects; j++) { for (int j = 0; j < numDisconnects; j++) {
if (randomBoolean()) { if (randomBoolean()) {
String node = "discoverable_node_added" + counter.incrementAndGet();
MockTransportService transportService = MockTransportService transportService =
startTransport("discoverable_node_added" + counter.incrementAndGet(), knownNodes, startTransport(node, knownNodes,
Version.CURRENT); Version.CURRENT);
discoverableTransports.add(transportService); discoverableTransports.add(transportService);
connection.addConnectedNode(transportService.getLocalDiscoNode()); seedNodes.add(Tuple.tuple(node, () -> transportService.getLocalDiscoNode()));
PlainActionFuture.get(fut -> connection.updateSeedNodes(null, seedNodes,
ActionListener.map(fut, x -> null)));
} else { } else {
DiscoveryNode node = randomFrom(discoverableNodes).v2().get(); DiscoveryNode node = randomFrom(discoverableNodes).v2().get();
connection.onNodeDisconnected(node); connection.onNodeDisconnected(node);
@ -1161,8 +1164,7 @@ public class RemoteClusterConnectionTests extends ESTestCase {
ConnectionManager delegate = new ConnectionManager(Settings.EMPTY, service.transport); ConnectionManager delegate = new ConnectionManager(Settings.EMPTY, service.transport);
StubbableConnectionManager connectionManager = new StubbableConnectionManager(delegate, Settings.EMPTY, service.transport); StubbableConnectionManager connectionManager = new StubbableConnectionManager(delegate, Settings.EMPTY, service.transport);
connectionManager.addNodeConnectedBehavior(connectedNode.getAddress(), (cm, discoveryNode) connectionManager.setDefaultNodeConnectedBehavior(cm -> Collections.singleton(connectedNode));
-> discoveryNode.equals(connectedNode));
connectionManager.addConnectBehavior(connectedNode.getAddress(), (cm, discoveryNode) -> { connectionManager.addConnectBehavior(connectedNode.getAddress(), (cm, discoveryNode) -> {
if (discoveryNode == connectedNode) { if (discoveryNode == connectedNode) {
@ -1174,7 +1176,7 @@ public class RemoteClusterConnectionTests extends ESTestCase {
service.acceptIncomingRequests(); service.acceptIncomingRequests();
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster", try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
seedNodes(connectedNode), service, Integer.MAX_VALUE, n -> true, null, connectionManager)) { seedNodes(connectedNode), service, Integer.MAX_VALUE, n -> true, null, connectionManager)) {
connection.addConnectedNode(connectedNode); PlainActionFuture.get(fut -> connection.ensureConnected(ActionListener.map(fut, x -> null)));
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
//always a direct connection as the remote node is already connected //always a direct connection as the remote node is already connected
Transport.Connection remoteConnection = connection.getConnection(connectedNode); Transport.Connection remoteConnection = connection.getConnection(connectedNode);

View File

@ -81,7 +81,7 @@ public class MockTransport implements Transport, LifecycleComponent {
@Nullable ClusterSettings clusterSettings, Set<String> taskHeaders) { @Nullable ClusterSettings clusterSettings, Set<String> taskHeaders) {
StubbableConnectionManager connectionManager = new StubbableConnectionManager(new ConnectionManager(settings, this), StubbableConnectionManager connectionManager = new StubbableConnectionManager(new ConnectionManager(settings, this),
settings, this); settings, this);
connectionManager.setDefaultNodeConnectedBehavior((cm, discoveryNode) -> nodeConnected(discoveryNode)); connectionManager.setDefaultNodeConnectedBehavior(cm -> Collections.emptySet());
connectionManager.setDefaultGetConnectionBehavior((cm, discoveryNode) -> createConnection(discoveryNode)); connectionManager.setDefaultGetConnectionBehavior((cm, discoveryNode) -> createConnection(discoveryNode));
return new TransportService(settings, this, threadPool, interceptor, localNodeFactory, clusterSettings, taskHeaders, return new TransportService(settings, this, threadPool, interceptor, localNodeFactory, clusterSettings, taskHeaders,
connectionManager); connectionManager);
@ -186,10 +186,6 @@ public class MockTransport implements Transport, LifecycleComponent {
protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) { protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
} }
protected boolean nodeConnected(DiscoveryNode discoveryNode) {
return true;
}
@Override @Override
public TransportStats getStats() { public TransportStats getStats() {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -500,15 +500,6 @@ public final class MockTransportService extends TransportService {
return connectionManager().setDefaultGetConnectionBehavior(behavior); return connectionManager().setDefaultGetConnectionBehavior(behavior);
} }
/**
* Adds a node connected behavior that is used for the given delegate address.
*
* @return {@code true} if no other node connected behavior was registered for this address before.
*/
public boolean addNodeConnectedBehavior(TransportAddress transportAddress, StubbableConnectionManager.NodeConnectedBehavior behavior) {
return connectionManager().addNodeConnectedBehavior(transportAddress, behavior);
}
/** /**
* Adds a node connected behavior that is the default node connected behavior. * Adds a node connected behavior that is the default node connected behavior.
* *

View File

@ -28,6 +28,7 @@ import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportConnectionListener; import org.elasticsearch.transport.TransportConnectionListener;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
@ -35,15 +36,13 @@ public class StubbableConnectionManager extends ConnectionManager {
private final ConnectionManager delegate; private final ConnectionManager delegate;
private final ConcurrentMap<TransportAddress, GetConnectionBehavior> getConnectionBehaviors; private final ConcurrentMap<TransportAddress, GetConnectionBehavior> getConnectionBehaviors;
private final ConcurrentMap<TransportAddress, NodeConnectedBehavior> nodeConnectedBehaviors;
private volatile GetConnectionBehavior defaultGetConnectionBehavior = ConnectionManager::getConnection; private volatile GetConnectionBehavior defaultGetConnectionBehavior = ConnectionManager::getConnection;
private volatile NodeConnectedBehavior defaultNodeConnectedBehavior = ConnectionManager::nodeConnected; private volatile NodeConnectedBehavior defaultNodeConnectedBehavior = ConnectionManager::connectedNodes;
public StubbableConnectionManager(ConnectionManager delegate, Settings settings, Transport transport) { public StubbableConnectionManager(ConnectionManager delegate, Settings settings, Transport transport) {
super(settings, transport); super(settings, transport);
this.delegate = delegate; this.delegate = delegate;
this.getConnectionBehaviors = new ConcurrentHashMap<>(); this.getConnectionBehaviors = new ConcurrentHashMap<>();
this.nodeConnectedBehaviors = new ConcurrentHashMap<>();
} }
public boolean addConnectBehavior(TransportAddress transportAddress, GetConnectionBehavior connectBehavior) { public boolean addConnectBehavior(TransportAddress transportAddress, GetConnectionBehavior connectBehavior) {
@ -56,10 +55,6 @@ public class StubbableConnectionManager extends ConnectionManager {
return prior == null; return prior == null;
} }
public boolean addNodeConnectedBehavior(TransportAddress transportAddress, NodeConnectedBehavior behavior) {
return nodeConnectedBehaviors.put(transportAddress, behavior) == null;
}
public boolean setDefaultNodeConnectedBehavior(NodeConnectedBehavior behavior) { public boolean setDefaultNodeConnectedBehavior(NodeConnectedBehavior behavior) {
NodeConnectedBehavior prior = defaultNodeConnectedBehavior; NodeConnectedBehavior prior = defaultNodeConnectedBehavior;
defaultNodeConnectedBehavior = behavior; defaultNodeConnectedBehavior = behavior;
@ -69,13 +64,11 @@ public class StubbableConnectionManager extends ConnectionManager {
public void clearBehaviors() { public void clearBehaviors() {
defaultGetConnectionBehavior = ConnectionManager::getConnection; defaultGetConnectionBehavior = ConnectionManager::getConnection;
getConnectionBehaviors.clear(); getConnectionBehaviors.clear();
defaultNodeConnectedBehavior = ConnectionManager::nodeConnected; defaultNodeConnectedBehavior = ConnectionManager::connectedNodes;
nodeConnectedBehaviors.clear();
} }
public void clearBehavior(TransportAddress transportAddress) { public void clearBehavior(TransportAddress transportAddress) {
getConnectionBehaviors.remove(transportAddress); getConnectionBehaviors.remove(transportAddress);
nodeConnectedBehaviors.remove(transportAddress);
} }
@Override @Override
@ -92,9 +85,12 @@ public class StubbableConnectionManager extends ConnectionManager {
@Override @Override
public boolean nodeConnected(DiscoveryNode node) { public boolean nodeConnected(DiscoveryNode node) {
TransportAddress address = node.getAddress(); return defaultNodeConnectedBehavior.connectedNodes(delegate).contains(node);
NodeConnectedBehavior behavior = nodeConnectedBehaviors.getOrDefault(address, defaultNodeConnectedBehavior); }
return behavior.nodeConnected(delegate, node);
@Override
public Set<DiscoveryNode> connectedNodes() {
return defaultNodeConnectedBehavior.connectedNodes(delegate);
} }
@Override @Override
@ -136,6 +132,6 @@ public class StubbableConnectionManager extends ConnectionManager {
@FunctionalInterface @FunctionalInterface
public interface NodeConnectedBehavior { public interface NodeConnectedBehavior {
boolean nodeConnected(ConnectionManager connectionManager, DiscoveryNode discoveryNode); Set<DiscoveryNode> connectedNodes(ConnectionManager connectionManager);
} }
} }