diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 45bfb099f2b..fdccfad7b47 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -228,7 +228,7 @@ abstract class AbstractSearchAsyncAction exten results.getSuccessfulResults().forEach((entry) -> { try { SearchShardTarget searchShardTarget = entry.getSearchShardTarget(); - Transport.Connection connection = getConnection(null, searchShardTarget.getNodeId()); + Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); sendReleaseSearchContext(entry.getRequestId(), connection, searchShardTarget.getOriginalIndices()); } catch (Exception inner) { inner.addSuppressed(exception); @@ -281,14 +281,12 @@ abstract class AbstractSearchAsyncAction exten @Override public final SearchResponse buildSearchResponse(InternalSearchResponse internalSearchResponse, String scrollId) { - ShardSearchFailure[] failures = buildShardFailures(); Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && failures.length > 0){ raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } - return new SearchResponse(internalSearchResponse, scrollId, getNumShards(), successfulOps.get(), skippedOps.get(), buildTookInMillis(), failures, clusters); } diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index 70f70268a0a..16df17bef1a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -19,32 +19,49 @@ package org.elasticsearch.action.search; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.cluster.routing.ShardRouting; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.Index; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.internal.ShardSearchTransportRequest; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.Transport; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.instanceOf; public class AbstractSearchAsyncActionTests extends ESTestCase { - private AbstractSearchAsyncAction createAction( - final boolean controlled, - final AtomicLong expected) { + private final List> resolvedNodes = new ArrayList<>(); + private final Set releasedContexts = new CopyOnWriteArraySet<>(); + private AbstractSearchAsyncAction createAction(SearchRequest request, + InitialSearchPhase.ArraySearchPhaseResults results, + ActionListener listener, + final boolean controlled, + final AtomicLong expected) { final Runnable runnable; final TransportSearchAction.SearchTimeProvider timeProvider; if (controlled) { @@ -61,18 +78,20 @@ public class AbstractSearchAsyncActionTests extends ESTestCase { System::nanoTime); } - final SearchRequest request = new SearchRequest(); - request.allowPartialSearchResults(true); - request.preference("_shards:1,3"); - return new AbstractSearchAsyncAction("test", null, null, null, + BiFunction nodeIdToConnection = (cluster, node) -> { + resolvedNodes.add(Tuple.tuple(cluster, node)); + return null; + }; + + return new AbstractSearchAsyncAction("test", null, null, nodeIdToConnection, Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())), Collections.singletonMap("foo", 2.0f), - Collections.singletonMap("name", Sets.newHashSet("bar", "baz")),null, request, null, + Collections.singletonMap("name", Sets.newHashSet("bar", "baz")), null, request, listener, new GroupShardsIterator<>( Collections.singletonList( new SearchShardIterator(null, null, Collections.emptyList(), null) ) ), timeProvider, 0, null, - new InitialSearchPhase.ArraySearchPhaseResults<>(10), request.getMaxConcurrentShardRequests(), + results, request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { @@ -89,6 +108,11 @@ public class AbstractSearchAsyncActionTests extends ESTestCase { runnable.run(); return super.buildTookInMillis(); } + + @Override + public void sendReleaseSearchContext(long contextId, Transport.Connection connection, OriginalIndices originalIndices) { + releasedContexts.add(contextId); + } }; } @@ -102,7 +126,8 @@ public class AbstractSearchAsyncActionTests extends ESTestCase { private void runTestTook(final boolean controlled) { final AtomicLong expected = new AtomicLong(); - AbstractSearchAsyncAction action = createAction(controlled, expected); + AbstractSearchAsyncAction action = createAction(new SearchRequest(), + new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, controlled, expected); final long actual = action.buildTookInMillis(); if (controlled) { // with a controlled clock, we can assert the exact took time @@ -114,8 +139,10 @@ public class AbstractSearchAsyncActionTests extends ESTestCase { } public void testBuildShardSearchTransportRequest() { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()).preference("_shards:1,3"); final AtomicLong expected = new AtomicLong(); - AbstractSearchAsyncAction action = createAction(false, expected); + AbstractSearchAsyncAction action = createAction(searchRequest, + new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, false, expected); String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); SearchShardIterator iterator = new SearchShardIterator(clusterAlias, new ShardId(new Index("name", "foo"), 1), Collections.emptyList(), new OriginalIndices(new String[] {"name", "name1"}, IndicesOptions.strictExpand())); @@ -129,4 +156,114 @@ public class AbstractSearchAsyncActionTests extends ESTestCase { assertEquals("_shards:1,3", shardSearchTransportRequest.preference()); assertEquals(clusterAlias, shardSearchTransportRequest.getClusterAlias()); } + + public void testBuildSearchResponse() { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()); + AbstractSearchAsyncAction action = createAction(searchRequest, + new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, false, new AtomicLong()); + String scrollId = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); + InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, scrollId); + assertEquals(scrollId, searchResponse.getScrollId()); + assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); + assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); + assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); + assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + } + + public void testBuildSearchResponseAllowPartialFailures() { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + AbstractSearchAsyncAction action = createAction(searchRequest, + new InitialSearchPhase.ArraySearchPhaseResults<>(10), null, false, new AtomicLong()); + action.onShardFailure(0, new SearchShardTarget("node", new ShardId("index", "index-uuid", 0), null, OriginalIndices.NONE), + new IllegalArgumentException()); + String scrollId = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); + InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, scrollId); + assertEquals(scrollId, searchResponse.getScrollId()); + assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); + assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); + assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); + assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + } + + public void testBuildSearchResponseDisallowPartialFailures() { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set); + Set requestIds = new HashSet<>(); + List> nodeLookups = new ArrayList<>(); + int numFailures = randomIntBetween(1, 5); + InitialSearchPhase.ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, numFailures); + AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); + for (int i = 0; i < numFailures; i++) { + ShardId failureShardId = new ShardId("index", "index-uuid", i); + String failureClusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); + String failureNodeId = randomAlphaOfLengthBetween(5, 10); + action.onShardFailure(i, new SearchShardTarget(failureNodeId, failureShardId, failureClusterAlias, OriginalIndices.NONE), + new IllegalArgumentException()); + } + action.buildSearchResponse(InternalSearchResponse.empty(), randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10)); + assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class)); + SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException)exception.get(); + assertEquals(0, searchPhaseExecutionException.getSuppressed().length); + assertEquals(numFailures, searchPhaseExecutionException.shardFailures().length); + for (ShardSearchFailure shardSearchFailure : searchPhaseExecutionException.shardFailures()) { + assertThat(shardSearchFailure.getCause(), instanceOf(IllegalArgumentException.class)); + } + assertEquals(nodeLookups, resolvedNodes); + assertEquals(requestIds, releasedContexts); + } + + public void testOnPhaseFailure() { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set); + Set requestIds = new HashSet<>(); + List> nodeLookups = new ArrayList<>(); + InitialSearchPhase.ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, 0); + AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); + action.onPhaseFailure(new SearchPhase("test") { + @Override + public void run() { + + } + }, "message", null); + assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class)); + SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException)exception.get(); + assertEquals("message", searchPhaseExecutionException.getMessage()); + assertEquals("test", searchPhaseExecutionException.getPhaseName()); + assertEquals(0, searchPhaseExecutionException.shardFailures().length); + assertEquals(0, searchPhaseExecutionException.getSuppressed().length); + assertEquals(nodeLookups, resolvedNodes); + assertEquals(requestIds, releasedContexts); + } + + private static InitialSearchPhase.ArraySearchPhaseResults phaseResults(Set requestIds, + List> nodeLookups, + int numFailures) { + int numResults = randomIntBetween(1, 10); + InitialSearchPhase.ArraySearchPhaseResults phaseResults = + new InitialSearchPhase.ArraySearchPhaseResults<>(numResults + numFailures); + + for (int i = 0; i < numResults; i++) { + long requestId = randomLong(); + requestIds.add(requestId); + SearchPhaseResult phaseResult = new PhaseResult(requestId); + String resultClusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); + String resultNodeId = randomAlphaOfLengthBetween(5, 10); + ShardId resultShardId = new ShardId("index", "index-uuid", i); + nodeLookups.add(Tuple.tuple(resultClusterAlias, resultNodeId)); + phaseResult.setSearchShardTarget(new SearchShardTarget(resultNodeId, resultShardId, resultClusterAlias, OriginalIndices.NONE)); + phaseResult.setShardIndex(i); + phaseResults.consumeResult(phaseResult); + } + return phaseResults; + } + + private static final class PhaseResult extends SearchPhaseResult { + PhaseResult(long requestId) { + this.requestId = requestId; + } + } }