Async Search: correct shards counting (#55758)

Async search allows users to retrieve partial results for a running search. For partial results, the number of successful shards does not include the skipped shards, while the response returned to users should.

Also, we recently had a bug where async search would miss tracking shard failures, which would have been caught if we had assertions in place that verified that whenever we get the last response, the number of failures included in it is the same as the failures that were tracked through the listener notifications.
This commit is contained in:
Luca Cavanna 2020-05-06 12:05:10 +02:00
parent 07ad742b60
commit 9a9cb68e83
3 changed files with 64 additions and 24 deletions

View File

@ -94,7 +94,8 @@ class MutableSearchResponse {
throw new IllegalStateException("received partial response out of order: " throw new IllegalStateException("received partial response out of order: "
+ reducePhase + " < " + this.reducePhase); + reducePhase + " < " + this.reducePhase);
} }
this.successfulShards = successfulShards; //when we get partial results skipped shards are not included in the provided number of successful shards
this.successfulShards = successfulShards + skippedShards;
this.totalHits = totalHits; this.totalHits = totalHits;
this.reducedAggsSource = reducedAggs; this.reducedAggsSource = reducedAggs;
this.reducePhase = reducePhase; this.reducePhase = reducePhase;
@ -106,6 +107,11 @@ class MutableSearchResponse {
*/ */
synchronized void updateFinalResponse(SearchResponse response) { synchronized void updateFinalResponse(SearchResponse response) {
failIfFrozen(); failIfFrozen();
assert response.getTotalShards() == totalShards : "received number of total shards differs from the one " +
"notified through onListShards";
assert response.getSkippedShards() == skippedShards : "received number of skipped shards differs from the one " +
"notified through onListShards";
assert response.getFailedShards() == buildShardFailures().length : "number of tracked failures differs from failed shards";
// copy the response headers from the current context // copy the response headers from the current context
this.responseHeaders = threadContext.getResponseHeaders(); this.responseHeaders = threadContext.getResponseHeaders();
this.finalResponse = response; this.finalResponse = response;
@ -121,6 +127,8 @@ class MutableSearchResponse {
failIfFrozen(); failIfFrozen();
// copy the response headers from the current context // copy the response headers from the current context
this.responseHeaders = threadContext.getResponseHeaders(); this.responseHeaders = threadContext.getResponseHeaders();
//note that when search fails, we may have gotten partial results before the failure. In that case async
// search will return an error plus the last partial results that were collected.
this.isPartial = true; this.isPartial = true;
this.failure = ElasticsearchException.guessRootCauses(exc)[0]; this.failure = ElasticsearchException.guessRootCauses(exc)[0];
this.frozen = true; this.frozen = true;

View File

@ -121,7 +121,7 @@ public class AsyncSearchResponseTests extends ESTestCase {
long tookInMillis = randomNonNegativeLong(); long tookInMillis = randomNonNegativeLong();
int totalShards = randomIntBetween(1, Integer.MAX_VALUE); int totalShards = randomIntBetween(1, Integer.MAX_VALUE);
int successfulShards = randomIntBetween(0, totalShards); int successfulShards = randomIntBetween(0, totalShards);
int skippedShards = totalShards - successfulShards; int skippedShards = randomIntBetween(0, successfulShards);
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
return new SearchResponse(internalSearchResponse, null, totalShards, return new SearchResponse(internalSearchResponse, null, totalShards,
successfulShards, skippedShards, tookInMillis, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY); successfulShards, skippedShards, tookInMillis, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY);

View File

@ -134,24 +134,58 @@ public class AsyncSearchTaskTests extends ESTestCase {
for (int i = 0; i < numSkippedShards; i++) { for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1))); skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
} }
int totalShards = numShards + numSkippedShards;
int numShardFailures = 0;
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false); task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) { for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1), task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true); assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true);
} }
task.getSearchProgressActionListener().onFinalReduce(shards, task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true); assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse( ((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards)); newSearchResponse(totalShards, totalShards, numSkippedShards));
assertCompletionListeners(task, numShards+numSkippedShards, assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, false);
numSkippedShards, numShardFailures, false);
} }
public void testWithFetchFailures() throws InterruptedException { public void testWithFetchFailures() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask();
int numShards = randomIntBetween(2, 10);
List<SearchShard> shards = new ArrayList<>();
for (int i = 0; i < numShards; i++) {
shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
List<SearchShard> skippedShards = new ArrayList<>();
int numSkippedShards = randomIntBetween(0, 10);
for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
}
int totalShards = numShards + numSkippedShards;
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true);
}
task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
int numFetchFailures = randomIntBetween(1, numShards - 1);
ShardSearchFailure[] shardSearchFailures = new ShardSearchFailure[numFetchFailures];
for (int i = 0; i < numFetchFailures; i++) {
IOException failure = new IOException("boum");
task.getSearchProgressActionListener().onFetchFailure(i,
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
failure);
shardSearchFailures[i] = new ShardSearchFailure(failure);
}
assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numFetchFailures, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
newSearchResponse(totalShards, totalShards - numFetchFailures, numSkippedShards, shardSearchFailures));
assertCompletionListeners(task, totalShards, totalShards - numFetchFailures, numSkippedShards, numFetchFailures, false);
}
public void testFatalFailureDuringFetch() throws InterruptedException {
AsyncSearchTask task = createAsyncSearchTask(); AsyncSearchTask task = createAsyncSearchTask();
int numShards = randomIntBetween(0, 10); int numShards = randomIntBetween(0, 10);
List<SearchShard> shards = new ArrayList<>(); List<SearchShard> shards = new ArrayList<>();
@ -163,27 +197,23 @@ public class AsyncSearchTaskTests extends ESTestCase {
for (int i = 0; i < numSkippedShards; i++) { for (int i = 0; i < numSkippedShards; i++) {
skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1))); skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
} }
int totalShards = numShards + numSkippedShards;
task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false); task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
for (int i = 0; i < numShards; i++) { for (int i = 0; i < numShards; i++) {
task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1), task.getSearchProgressActionListener().onPartialReduce(shards.subList(0, i+1),
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, 0, true); assertCompletionListeners(task, totalShards, i + 1 + numSkippedShards, numSkippedShards, 0, true);
} }
task.getSearchProgressActionListener().onFinalReduce(shards, task.getSearchProgressActionListener().onFinalReduce(shards,
new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0); new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
int numFetchFailures = randomIntBetween(0, numShards); for (int i = 0; i < numShards; i++) {
ShardSearchFailure[] failures = new ShardSearchFailure[numFetchFailures]; task.getSearchProgressActionListener().onFetchFailure(i,
for (int i = 0; i < numFetchFailures; i++) { new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
failures[i] = new ShardSearchFailure(new IOException("boum"), new IOException("boum"));
new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE));
task.getSearchProgressActionListener().onFetchFailure(i, failures[i].shard(), (Exception) failures[i].getCause());
} }
assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true); assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numShards, true);
((AsyncSearchTask.Listener)task.getProgressListener()).onResponse( ((AsyncSearchTask.Listener)task.getProgressListener()).onFailure(new IOException("boum"));
newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards, failures)); assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numShards, true);
assertCompletionListeners(task, numShards+numSkippedShards,
numSkippedShards, numFetchFailures, false);
} }
private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards, private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards,
@ -194,8 +224,9 @@ public class AsyncSearchTaskTests extends ESTestCase {
100, failures, SearchResponse.Clusters.EMPTY); 100, failures, SearchResponse.Clusters.EMPTY);
} }
private void assertCompletionListeners(AsyncSearchTask task, private static void assertCompletionListeners(AsyncSearchTask task,
int expectedTotalShards, int expectedTotalShards,
int expectedSuccessfulShards,
int expectedSkippedShards, int expectedSkippedShards,
int expectedShardFailures, int expectedShardFailures,
boolean isPartial) throws InterruptedException { boolean isPartial) throws InterruptedException {
@ -206,6 +237,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
@Override @Override
public void onResponse(AsyncSearchResponse resp) { public void onResponse(AsyncSearchResponse resp) {
assertThat(resp.getSearchResponse().getTotalShards(), equalTo(expectedTotalShards)); assertThat(resp.getSearchResponse().getTotalShards(), equalTo(expectedTotalShards));
assertThat(resp.getSearchResponse().getSuccessfulShards(), equalTo(expectedSuccessfulShards));
assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards)); assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards));
assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures)); assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures));
assertThat(resp.isPartial(), equalTo(isPartial)); assertThat(resp.isPartial(), equalTo(isPartial));