From 4c6bfe32a7651f20e96f81172f5a3ef3a5229813 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 11 Jun 2020 18:53:06 +0200 Subject: [PATCH] Fix possible NPE on search phase failure (#57952) When a search phase fails, we release the context of all successful shards. Successful shards that rewrite the request to match none will not create any context since #. This change ensures that we don't try to release a `null` context on these successful shards. Closes #57945 --- .../search/AbstractSearchAsyncAction.java | 16 +-- .../action/search/SearchTransportService.java | 3 +- .../search/SearchPhaseResult.java | 3 + .../action/search/SearchAsyncActionTests.java | 108 ++++++++++++++++++ 4 files changed, 122 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index f6f90cc7ccb..66548adccb5 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -559,13 +559,15 @@ abstract class AbstractSearchAsyncAction exten */ private void raisePhaseFailure(SearchPhaseExecutionException exception) { results.getSuccessfulResults().forEach((entry) -> { - try { - SearchShardTarget searchShardTarget = entry.getSearchShardTarget(); - Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); - sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices()); - } catch (Exception inner) { - inner.addSuppressed(exception); - logger.trace("failed to release context", inner); + if (entry.getContextId() != null) { + try { + SearchShardTarget searchShardTarget = entry.getSearchShardTarget(); + Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); + sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices()); + } catch (Exception inner) { + inner.addSuppressed(exception); + logger.trace("failed to release context", inner); + } } }); listener.onFailure(exception); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 681177e4399..5cb39d68c39 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -58,6 +58,7 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.function.BiFunction; /** @@ -199,7 +200,7 @@ public class SearchTransportService { private SearchContextId contextId; ScrollFreeContextRequest(SearchContextId contextId) { - this.contextId = contextId; + this.contextId = Objects.requireNonNull(contextId); } ScrollFreeContextRequest(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java index 15fab25f01b..879110314a7 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java +++ b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java @@ -19,6 +19,7 @@ package org.elasticsearch.search; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.fetch.FetchSearchResult; @@ -52,7 +53,9 @@ public abstract class SearchPhaseResult extends TransportResponse { /** * Returns the search context ID that is used to reference the search context on the executing node + * or null if no context was created. */ + @Nullable public SearchContextId getContextId() { return contextId; } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index dae2ae7ecbe..9620bc4876a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -59,6 +59,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; public class SearchAsyncActionTests extends ESTestCase { @@ -376,6 +377,113 @@ public class SearchAsyncActionTests extends ESTestCase { executor.shutdown(); } + public void testFanOutAndFail() throws InterruptedException { + SearchRequest request = new SearchRequest(); + request.allowPartialSearchResults(true); + request.setMaxConcurrentShardRequests(randomIntBetween(1, 100)); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + ActionListener responseListener = ActionListener.wrap( + searchResponse -> { throw new AssertionError("unexpected response"); }, + exc -> { + failure.set(exc); + latch.countDown(); + }); + DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT); + DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT); + + Map> nodeToContextMap = newConcurrentMap(); + AtomicInteger contextIdGenerator = new AtomicInteger(0); + int numShards = randomIntBetween(2, 10); + GroupShardsIterator shardsIter = getShardsIter("idx", + new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS), + numShards, randomBoolean(), primaryNode, replicaNode); + AtomicInteger numFreedContext = new AtomicInteger(); + SearchTransportService transportService = new SearchTransportService(null, null) { + @Override + public void sendFreeContext(Transport.Connection connection, SearchContextId contextId, OriginalIndices originalIndices) { + assertNotNull(contextId); + numFreedContext.incrementAndGet(); + assertTrue(nodeToContextMap.containsKey(connection.getNode())); + assertTrue(nodeToContextMap.get(connection.getNode()).remove(contextId)); + } + }; + Map lookup = new HashMap<>(); + lookup.put(primaryNode.getId(), new MockConnection(primaryNode)); + lookup.put(replicaNode.getId(), new MockConnection(replicaNode)); + Map aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)); + ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors())); + AbstractSearchAsyncAction asyncAction = + new AbstractSearchAsyncAction( + "test", + logger, + transportService, + (cluster, node) -> { + assert cluster == null : "cluster was not null: " + cluster; + return lookup.get(node); }, + aliasFilters, + Collections.emptyMap(), + Collections.emptyMap(), + executor, + request, + responseListener, + shardsIter, + new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), + ClusterState.EMPTY_STATE, + null, + new ArraySearchPhaseResults<>(shardsIter.size()), + request.getMaxConcurrentShardRequests(), + SearchResponse.Clusters.EMPTY) { + TestSearchResponse response = new TestSearchResponse(); + + @Override + protected void executePhaseOnShard(SearchShardIterator shardIt, + ShardRouting shard, + SearchActionListener listener) { + assertTrue("shard: " + shard.shardId() + " has been queried twice", response.queried.add(shard.shardId())); + Transport.Connection connection = getConnection(null, shard.currentNodeId()); + final TestSearchPhaseResult testSearchPhaseResult; + if (shard.shardId().id() == 0) { + testSearchPhaseResult = new TestSearchPhaseResult(null, connection.getNode()); + } else { + testSearchPhaseResult = new TestSearchPhaseResult(new SearchContextId(UUIDs.randomBase64UUID(), + contextIdGenerator.incrementAndGet()), connection.getNode()); + Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet()); + ids.add(testSearchPhaseResult.getContextId()); + } + if (randomBoolean()) { + listener.onResponse(testSearchPhaseResult); + } else { + new Thread(() -> listener.onResponse(testSearchPhaseResult)).start(); + } + } + + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, + SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + throw new RuntimeException("boom"); + } + }; + } + }; + asyncAction.start(); + latch.await(); + assertNotNull(failure.get()); + assertThat(failure.get().getCause().getMessage(), containsString("boom")); + assertFalse(nodeToContextMap.isEmpty()); + assertTrue(nodeToContextMap.toString(), nodeToContextMap.containsKey(primaryNode) || nodeToContextMap.containsKey(replicaNode)); + assertEquals(shardsIter.size()-1, numFreedContext.get()); + if (nodeToContextMap.containsKey(primaryNode)) { + assertTrue(nodeToContextMap.get(primaryNode).toString(), nodeToContextMap.get(primaryNode).isEmpty()); + } else { + assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty()); + } + executor.shutdown(); + } + public void testAllowPartialResults() throws InterruptedException { SearchRequest request = new SearchRequest(); request.allowPartialSearchResults(false);