Preserve thread context when connecting to remote cluster (#31574)

Establishing remote cluster connections uses a queue to coordinate multiple concurrent connect
attempts. Connect attempts can be initiated by user-triggered searches as well as by system events
(e.g. when nodes disconnect). When several such attempts run concurrently, the connectListener of
one attempt can end up being called under the thread context of another connect attempt. This leads
to the situation seen in #31462, where the connect listener is executed under the system context,
which breaks fetching the search shards from the remote cluster.
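
The fix wraps the listener before it is put on the connect queue, so that it is completed under the
thread context of the caller that enqueued it. A minimal caller-side sketch of the pattern, using the
same ES types the change relies on (the method name and the "user-header" key are invented for
illustration):

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;

class ConnectListenerSketch {
    // Returns a listener that still sees the caller's headers when it is eventually completed,
    // regardless of which connect attempt (user- or system-triggered) finishes the work.
    static ActionListener<Void> wrappedConnectListener(ThreadPool threadPool) {
        ThreadContext threadContext = threadPool.getThreadContext();
        threadContext.putHeader("user-header", "value"); // caller-specific context, e.g. security headers

        // Unwrapped, this listener could run under another attempt's (e.g. the system) context,
        // and "user-header" would not be visible in the callback.
        ActionListener<Void> connectListener = ActionListener.wrap(
            aVoid -> { assert "value".equals(threadContext.getHeader("user-header")); },
            e -> { /* failure handling elided */ });

        // The fix: capture the caller's context now and restore it around the callbacks.
        return ContextPreservingActionListener.wrapPreservingContext(connectListener, threadContext);
    }
}

The diff below applies exactly this wrapping inside RemoteClusterConnection#connect.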

Closes #31462
Yannick Welsch 2018-06-27 10:46:42 +02:00 committed by GitHub
parent 4fc833b1de
commit b7246199db
2 changed files with 64 additions and 2 deletions

RemoteClusterConnection.java

@@ -29,6 +29,7 @@ import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse
 import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -369,9 +370,11 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo
     private void connect(ActionListener<Void> connectListener, boolean forceRun) {
         final boolean runConnect;
         final Collection<ActionListener<Void>> toNotify;
+        final ActionListener<Void> listener = connectListener == null ? null :
+            ContextPreservingActionListener.wrapPreservingContext(connectListener, transportService.getThreadPool().getThreadContext());
         synchronized (queue) {
-            if (connectListener != null && queue.offer(connectListener) == false) {
-                connectListener.onFailure(new RejectedExecutionException("connect queue is full"));
+            if (listener != null && queue.offer(listener) == false) {
+                listener.onFailure(new RejectedExecutionException("connect queue is full"));
                 return;
             }
             if (forceRun == false && queue.isEmpty()) {
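
For context, ContextPreservingActionListener captures the current thread context when it is created
and restores it around the delegate's onResponse/onFailure calls. A simplified sketch of that pattern,
assuming ThreadContext#newRestorableContext as used elsewhere in the code base (this is not the exact
Elasticsearch implementation):

import java.util.function.Supplier;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.ThreadContext;

final class PreservingListenerSketch<R> implements ActionListener<R> {
    private final Supplier<ThreadContext.StoredContext> context; // captures the caller's context at wrap time
    private final ActionListener<R> delegate;

    PreservingListenerSketch(ThreadContext threadContext, ActionListener<R> delegate) {
        this.context = threadContext.newRestorableContext(true);
        this.delegate = delegate;
    }

    @Override
    public void onResponse(R response) {
        // restore the captured context for the duration of the callback, then revert
        try (ThreadContext.StoredContext ignore = context.get()) {
            delegate.onResponse(response);
        }
    }

    @Override
    public void onFailure(Exception e) {
        try (ThreadContext.StoredContext ignore = context.get()) {
            delegate.onFailure(e);
        }
    }
}

The new test below exercises this end to end: every thread sets its own threadId header before calling
fetchSearchShards and asserts that the header is still present when the response listener runs.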

RemoteClusterConnectionTests.java

@@ -42,6 +42,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.CancellableThreads;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -555,6 +556,64 @@ public class RemoteClusterConnectionTests extends ESTestCase {
         }
     }
+
+    public void testFetchShardsThreadContextHeader() throws Exception {
+        List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
+        try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
+             MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) {
+            DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
+            knownNodes.add(seedTransport.getLocalDiscoNode());
+            knownNodes.add(discoverableTransport.getLocalDiscoNode());
+            Collections.shuffle(knownNodes, random());
+            try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
+                service.start();
+                service.acceptIncomingRequests();
+                List<DiscoveryNode> nodes = Collections.singletonList(seedNode);
+                try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
+                    nodes, service, Integer.MAX_VALUE, n -> true)) {
+                    SearchRequest request = new SearchRequest("test-index");
+                    Thread[] threads = new Thread[10];
+                    for (int i = 0; i < threads.length; i++) {
+                        final String threadId = Integer.toString(i);
+                        threads[i] = new Thread(() -> {
+                            ThreadContext threadContext = seedTransport.threadPool.getThreadContext();
+                            threadContext.putHeader("threadId", threadId);
+                            AtomicReference<ClusterSearchShardsResponse> reference = new AtomicReference<>();
+                            AtomicReference<Exception> failReference = new AtomicReference<>();
+                            final ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index")
+                                .indicesOptions(request.indicesOptions()).local(true).preference(request.preference())
+                                .routing(request.routing());
+                            CountDownLatch responseLatch = new CountDownLatch(1);
+                            connection.fetchSearchShards(searchShardsRequest,
+                                new LatchedActionListener<>(ActionListener.wrap(
+                                    resp -> {
+                                        reference.set(resp);
+                                        assertEquals(threadId, seedTransport.threadPool.getThreadContext().getHeader("threadId"));
+                                    },
+                                    failReference::set), responseLatch));
+                            try {
+                                responseLatch.await();
+                            } catch (InterruptedException e) {
+                                throw new RuntimeException(e);
+                            }
+                            assertNull(failReference.get());
+                            assertNotNull(reference.get());
+                            ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get();
+                            assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes()));
+                        });
+                    }
+                    for (int i = 0; i < threads.length; i++) {
+                        threads[i].start();
+                    }
+                    for (int i = 0; i < threads.length; i++) {
+                        threads[i].join();
+                    }
+                    assertTrue(connection.assertNoRunningConnections());
+                }
+            }
+        }
+    }
+
     public void testFetchShardsSkipUnavailable() throws Exception {
         List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
         try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {